线段树
算法 | python | cpp | 数据结构 | 模板
2025年11月26日
只提供一种实现：ZKW 线段树（非递归、自底向上），支持区间加与区间求和。下标从 1 开始，query(l, r) 与 update(l, r, val) 均作用于闭区间 [l, r]。参考资料：《浅析线段树实现》。
Python
class ZKWTr:
    """Bottom-up (ZKW) segment tree with lazy range add and range sum.

    Positions are 1-based: position ``i`` (``1 <= i <= n``) is stored in leaf
    ``n + i``, and ``query`` / ``update`` take a closed range
    ``[left, right]``.  Internal nodes are ``1 .. n - 1`` with children
    ``2u`` and ``2u + 1``; slots ``0``, ``n`` and ``2n + 1`` never hold data.
    """

    def __init__(self, size: int):
        self.n = size
        # floor(log2(size)) (0 for size <= 1).  Only this one value is used,
        # so there is no need for an O(size) lookup table.
        self.log_n = max(size.bit_length() - 1, 0)
        # tr[u]: sum of the leaves under u, excluding adds still pending in
        # the tags of u's ancestors.
        self.tr = [0] * (2 * size + 2)
        # tag[u]: per-leaf add already applied to tr[u] but not yet to u's
        # children.
        self.tag = [0] * (2 * size + 2)
        # size[u]: number of leaf slots under u.  Computed here rather than in
        # build() so range updates are correct even if build() is never called.
        self.size = [1] * (2 * size + 2)
        for u in range(size - 1, 0, -1):
            self.size[u] = self.size[u << 1] + self.size[u << 1 | 1]

    def build(self, arr: list[int]) -> None:
        """Store ``arr[k]`` at position ``k + 1`` and recompute internal sums.

        Positions beyond ``len(arr)`` keep their current values.

        Raises:
            ValueError: if ``arr`` has more than ``n`` elements; the slice
                assignment below would otherwise silently grow ``self.tr``.
        """
        if len(arr) > self.n:
            raise ValueError(
                f"build() got {len(arr)} values for a tree of size {self.n}"
            )
        # Flush pending tags toward the leaves (parents have smaller indices,
        # so ascending order is top-down).  Otherwise a rebuild after update()
        # would later push stale tags onto the new values.
        for u in range(1, self.n):
            self.push_down(u)
        self.tr[self.n + 1 : self.n + 1 + len(arr)] = arr
        for u in range(self.n - 1, 0, -1):
            self.push_up(u)

    def push_up(self, u: int) -> None:
        """Recompute internal node ``u`` as the sum of its two children.

        ``u == 0`` is the "no node" sentinel used by ``update``.  ``u >= n``
        is a leaf (``u == n`` is the padding slot, kept at 0), so both are
        ignored.
        """
        if not 0 < u < self.n:
            return
        self.tr[u] = self.tr[u << 1] + self.tr[u << 1 | 1]

    def modify(self, u: int, val: int) -> None:
        """Add ``val`` to every leaf under ``u``: now for ``tr[u]``, lazily below."""
        self.tr[u] += self.size[u] * val
        self.tag[u] += val

    def push_down(self, u: int) -> None:
        """Hand node ``u``'s pending tag to its two children."""
        if self.tag[u]:
            self.modify(u << 1, self.tag[u])
            self.modify(u << 1 | 1, self.tag[u])
            self.tag[u] = 0

    def query(self, left: int, right: int) -> int:
        """Return the sum of positions ``left .. right`` (1-based, inclusive).

        Expects ``1 <= left`` and ``right <= n``.
        """
        left += self.n  # first leaf in the range
        right += self.n + 1  # one past the last leaf: half-open [left, right)
        # Push tags down both boundary paths, root first, so every node read
        # below already reflects all updates stored above it.
        for i in range(self.log_n + 1, 0, -1):
            self.push_down(left >> i)
            self.push_down(right >> i)
        rst = 0
        while left < right:
            if left & 1:
                rst += self.tr[left]
                left += 1
            if right & 1:
                right -= 1
                rst += self.tr[right]
            left >>= 1
            right >>= 1
        return rst

    def update(self, left: int, right: int, val: int) -> None:
        """Add ``val`` to positions ``left .. right`` (1-based, inclusive).

        Expects ``1 <= left`` and ``right <= n``.
        """
        left += self.n
        right += self.n + 1
        for i in range(self.log_n + 1, 0, -1):
            self.push_down(left >> i)
            self.push_down(right >> i)
        # u / v: latest node on the left / right boundary path
        # (0 = nothing modified on that side yet).
        u = v = 0
        while left < right:
            if left & 1:
                u = left
                self.modify(left, val)
                left += 1
            if right & 1:
                # v takes right before the decrement: for odd right,
                # right >> 1 is also the parent of the modified node right - 1.
                v = right
                right -= 1
                self.modify(right, val)
            # Refresh each path's parent.  Once the range has closed
            # (left == right), keep climbing all the way to the root.
            while True:
                u >>= 1
                self.push_up(u)
                if not (left == right and u):
                    break
            while True:
                v >>= 1
                self.push_up(v)
                if not (left == right and v):
                    break
            left >>= 1
            right >>= 1
C++
// Bottom-up (ZKW) segment tree with lazy range add and range sum.
// Positions are 1-based: position i (1 <= i <= n) lives in leaf n + i, and
// query/update take a closed range [l, r].  Internal nodes are 1..n-1 with
// children 2u and 2u+1; slots 0, n and 2n + 1 never hold data.
template <typename intType> struct ZKWTr {
    int n;
    // tr[u]: sum of leaves under u (excluding adds pending in ancestors);
    // tag[u]: per-leaf add applied to tr[u] but not yet to u's children.
    std::vector<intType> tr, tag;
    std::vector<size_t> sz;  // sz[u]: number of leaf slots under node u
    int log_n;               // floor(log2(n)); 0 when n <= 1

    // Child indices.  These are functions rather than the previous
    // unparenthesized macros, which misparse in expressions such as
    // `l_ch + 1` (-> u << 2) and leak into the global namespace.
    static int lc(int u) { return u << 1; }
    static int rc(int u) { return u << 1 | 1; }

    ZKWTr(size_t size) {
        n = static_cast<int>(size);
        // Integer floor(log2(n)): no floating point, and no UB for n == 0
        // (log2(0) is -inf).
        log_n = 0;
        while ((n >> (log_n + 1)) > 0) ++log_n;
        tr.assign(2 * n + 2, 0);
        tag.assign(2 * n + 2, 0);
        sz.assign(2 * n + 2, 1);
        // `u > 0` (not just `u`) so n == 0 does not start the loop at -1.
        for (int u = n - 1; u > 0; u--) sz[u] = sz[lc(u)] + sz[rc(u)];
    }

    // Store vec[k] at position k + 1 for k < min(vec.size(), n) and recompute
    // internal sums; later positions keep their current values.
    void build(const std::vector<intType> &vec) {
        // Flush pending tags toward the leaves (parents have smaller indices,
        // so ascending order is top-down); otherwise a rebuild after update()
        // would later push stale tags onto the new values.
        for (int u = 1; u < n; u++) push_down(u);
        // At most n values fit; copying more would write past tr's end.
        std::copy_n(vec.begin(), std::min(vec.size(), static_cast<size_t>(n)),
                    tr.begin() + n + 1);
        for (int u = n - 1; u > 0; u--) tr[u] = tr[lc(u)] + tr[rc(u)];
    }

    // Recompute internal node u from its children.  u == 0 is update()'s
    // "no node" sentinel and u >= n is a leaf (u == n is padding kept at 0).
    void push_up(int u) {
        if (u <= 0 || u >= n) return;
        tr[u] = tr[lc(u)] + tr[rc(u)];
    }

    // Add val to every leaf under u: immediately for tr[u], lazily below.
    void modify(int u, intType val) {
        // Cast so the product is computed in intType, not in unsigned size_t.
        tr[u] += static_cast<intType>(sz[u]) * val;
        tag[u] += val;
    }

    // Hand u's pending tag to its two children.
    void push_down(int u) {
        if (tag[u]) {
            modify(lc(u), tag[u]);
            modify(rc(u), tag[u]);
            tag[u] = 0;
        }
    }

    // Sum of positions l..r (1-based, inclusive); expects 1 <= l, r <= n.
    intType query(int l, int r) {
        l += n, r += n + 1;  // half-open leaf range [l, r)
        // Push tags down both boundary paths, root first, so every node read
        // below already reflects all updates stored above it.
        for (int i = log_n + 1; i; i--) push_down(l >> i), push_down(r >> i);
        intType rst = 0;
        for (; l < r; l >>= 1, r >>= 1) {
            if (l & 1) rst += tr[l++];
            if (r & 1) rst += tr[--r];
        }
        return rst;
    }

    // Add val to positions l..r (1-based, inclusive); expects 1 <= l, r <= n.
    void update(int l, int r, intType val) {
        l += n, r += n + 1;
        for (int i = log_n + 1; i; i--) push_down(l >> i), push_down(r >> i);
        // u / v: latest node on the left / right boundary path (0 = none yet).
        for (int u = 0, v = 0; l < r; l >>= 1, r >>= 1) {
            if (l & 1) u = l, modify(l++, val);
            // v takes r before the decrement: for odd r, r >> 1 is also the
            // parent of the modified node r - 1.
            if (r & 1) v = r, modify(--r, val);
            // Refresh each path's parent; once the range has closed
            // (l == r), keep climbing all the way to the root.
            do push_up(u >>= 1);
            while (l == r && u);
            do push_up(v >>= 1);
            while (l == r && v);
        }
    }
};