diff --git a/segmenttree/persistent_segtree.hpp b/segmenttree/persistent_segtree.hpp new file mode 100644 index 00000000..50020da1 --- /dev/null +++ b/segmenttree/persistent_segtree.hpp @@ -0,0 +1,147 @@ +#pragma once + +#include +#include +#include +#include + +template struct persistent_segtree { + static_assert(std::is_convertible_v>, + "op must work as S(S, S)"); + static_assert(std::is_convertible_v>, "e must work as S()"); + + struct Root { + int id; + }; + + explicit persistent_segtree(int n) : persistent_segtree(std::vector(n, e())) {} + explicit persistent_segtree(const std::vector &v) : _n(int(v.size())) { + size = std::bit_ceil((unsigned int)_n); + lg = std::countr_zero((unsigned int)size); + nodes.assign(2 * size, Node{e(), -1, -1}); + + for (int i = 0; i < _n; i++) nodes[size + i].val = v[i]; + for (int i = size - 1; i >= 1; i--) { + nodes[i].left = 2 * i; + nodes[i].right = 2 * i + 1; + nodes[i].val = op(nodes[2 * i].val, nodes[2 * i + 1].val); + } + } + + Root set(const Root &root, int p, const S &x) { + assert(0 <= p && p < _n); + + std::vector ids(lg + 1); + + ids[lg] = root.id; + for (int i = lg - 1; i >= 0; --i) { + const Node &par = nodes[ids[i + 1]]; + ids[i] = ((p >> i) & 1) ? par.right : par.left; + } + + int copy_cur = new_node(x, -1, -1); + + for (int i = 1; i <= lg; i++) { + const int par = ids[i], cur = ids[i - 1]; + const Node &par_node = nodes[par]; + const int left = par_node.left == cur ? copy_cur : par_node.left; + const int right = par_node.right == cur ? copy_cur : par_node.right; + copy_cur = new_node(op(nodes[left].val, nodes[right].val), left, right); + } + + return Root{copy_cur}; + } + + S get(const Root &root, int p) const { + assert(0 <= p && p < _n); + int i = root.id; + for (int bit = lg - 1; bit >= 0; --bit) { + i = ((p >> bit) & 1) ? nodes[i].right : nodes[i].left; + } + return nodes[i].val; + } + + S prod(const Root &root, int l, int r) const { + assert(0 <= l && l <= r && r <= _n); + auto rec = [&](auto &&self, int i, int lo, int hi) -> S { + if (r <= lo || hi <= l) return e(); + if (l <= lo && hi <= r) return nodes[i].val; + const int mid = (lo + hi) >> 1; + return op(self(self, nodes[i].left, lo, mid), self(self, nodes[i].right, mid, hi)); + }; + return rec(rec, root.id, 0, size); + } + + S all_prod(const Root &root) const { return nodes[root.id].val; } + + template int max_right(const Root &root, int l) const { + return max_right(root, l, [](S x) { return f(x); }); + } + template int max_right(const Root &root, int l, F f) const { + assert(0 <= l && l <= _n); + assert(f(e())); + if (l == _n) return _n; + S sm = e(); + auto rec = [&](auto &&self, int i, int lo, int hi) -> int { + if (hi <= l) return hi; + if (l <= lo) { + const S nxt = op(sm, nodes[i].val); + if (f(nxt)) { + sm = nxt; + return hi; + } + if (hi - lo == 1) return lo; + } + const int mid = (lo + hi) >> 1; + if (l < mid) { + const int left_res = self(self, nodes[i].left, lo, mid); + if (left_res < mid) return left_res; + } + return self(self, nodes[i].right, mid, hi); + }; + return std::min(rec(rec, root.id, 0, size), _n); + } + + template int min_left(const Root &root, int r) const { + return min_left(root, r, [](S x) { return f(x); }); + } + template int min_left(const Root &root, int r, F f) const { + assert(0 <= r && r <= _n); + assert(f(e())); + if (r == 0) return 0; + S sm = e(); + auto rec = [&](auto &&self, int i, int lo, int hi) -> int { + if (r <= lo) return lo; + if (hi <= r) { + const S nxt = op(nodes[i].val, sm); + if (f(nxt)) { + sm = nxt; + return lo; + } + if (hi - lo == 1) return hi; + } + const int mid = (lo + hi) >> 1; + if (mid < r) { + const int right_res = self(self, nodes[i].right, mid, hi); + if (mid < right_res) return right_res; + } + return self(self, nodes[i].left, lo, mid); + }; + return rec(rec, root.id, 0, size); + } + + Root get_root() const { return {1}; } + + struct Node { + S val; + int left, right; + }; + + int _n, size, lg; + std::vector nodes; + + int new_node(const S &val, int left, int right) { + nodes.push_back(Node{val, left, right}); + return int(nodes.size()) - 1; + } +}; diff --git a/segmenttree/persistent_segtree.md b/segmenttree/persistent_segtree.md new file mode 100644 index 00000000..335f8b93 --- /dev/null +++ b/segmenttree/persistent_segtree.md @@ -0,0 +1,51 @@ +--- +title: Persistent segtree (完全永続セグメント木) +documentation_of: ./persistent_segtree.hpp +--- + +完全永続版のセグメント木.各点更新のたびに新しい版の根を返し,過去の任意の版に対して 1 点更新や区間積クエリを行える.インターフェースは `atcoder::segtree` に近く,第一引数に更新のもととなる版の根を与える点が異なる. + +## 使用方法 + +```cpp +struct S { + unsigned long long sum; + int len; +}; +S op(S l, S r) { return {l.sum + r.sum, l.len + r.len}; } +S e() { return {0, 0}; } + +vector A(N, {0, 1}); +persistent_segtree seg(A); + +auto root0 = seg.get_root(); // 初期版 +auto root1 = seg.set(root0, idx, {x, 1}); // idx 番目を更新した新しい版 + +S x = seg.get(root0, idx); // root0 版の idx 番目の値 +S y = seg.prod(root1, l, r); // root1 版の [l, r) の積 +S z = seg.all_prod(root1); // root1 版の列全体の積 + +int i = seg.max_right(root1, l, [](S x) { return x.sum <= lim; }); +int j = seg.min_left(root1, r, [](S x) { return x.sum <= lim; }); +``` + +`max_right`, `min_left` の意味は `atcoder::segtree` と同じ.すなわち,`f(e()) = true` を満たす単調な述語 `f` に対して, + +- `max_right(root, l, f)` は `prod(root, l, r)` が `f` を満たすような最大の `r` を返す. +- `min_left(root, r, f)` は `prod(root, l, r)` が `f` を満たすような最小の `l` を返す. + +計算量は以下の通り. + +- 構築 $O(N)$ +- `set` $O(\log N)$ 時間,更新 1 回あたり追加ノード数 $O(\log N)$ +- `get`, `prod`, `max_right`, `min_left` $O(\log N)$ +- `all_prod` $O(1)$ + +## 問題例 + +- [The 1st Universal Cup. Stage 15: Hangzhou G. Game: Celeste - Problem - QOJ.ac](https://qoj.ac/contest/1221/problem/6400) +- [AtCoder Beginner Contest 453 G - Copy Query](https://atcoder.jp/contests/abc453/tasks/abc453_g) + +## Link + +- [永続セグメント木 - AtCoderInfo](https://info.atcoder.jp/entry/algorithm_lectures/persistent_segment_tree)