线段树

算法 | python | cpp | 数据结构 | 模板

2025年11月26日

只提供一种实现：ZKW 线段树（自底向上的非递归线段树，带懒标记，支持区间加与区间求和）。下标从 1 开始，`query`/`update` 的区间 [l, r] 为闭区间。参考资料：《浅析线段树实现》。

Python

class ZKWTr:
    """Non-recursive (ZKW-style) segment tree with lazy range add and range sum.

    Elements are 1-indexed: element ``i`` (``1 <= i <= n``) is stored at leaf
    ``n + i``, and node ``u`` has children ``2u`` and ``2u + 1``.
    ``update(l, r, val)`` adds ``val`` to every element of ``[l, r]`` and
    ``query(l, r)`` returns the sum over ``[l, r]``; both bounds are inclusive.
    """

    def __init__(self, size: int):
        """Create a tree holding ``size`` elements, all initially 0."""
        self.n = size
        # log[i] == floor(log2(i)) for i >= 1. Only log[n] is read: it bounds
        # how many levels the top-down push_down walk must visit.
        self.log = [0] * (self.n + 1)
        for i in range(2, self.n + 1):
            self.log[i] = self.log[i // 2] + 1
        # tr[u]: sum of u's subtree, excluding adds still pending in ancestors.
        self.tr = [0] * (2 * self.n + 2)
        # tag[u]: per-element add already counted in tr[u] but not yet pushed
        # to u's children.
        self.tag = [0] * (2 * self.n + 2)
        # size[u]: number of leaves under u. It depends only on the tree shape,
        # so it is computed here; range updates scale by it even if build() is
        # never called.
        self.size = [1] * (2 * self.n + 2)
        for i in range(self.n - 1, 0, -1):
            self.size[i] = self.size[i << 1] + self.size[i << 1 | 1]

    def build(self, arr: list[int]):
        """Load ``arr`` as elements ``1..len(arr)`` and recompute all sums.

        Previous values and pending tags are discarded, so the tree can be
        rebuilt after updates.

        Raises:
            ValueError: if ``arr`` has more than ``n`` elements (the slice
                assignment would otherwise silently grow ``tr``).
        """
        n = self.n
        if len(arr) > n:
            raise ValueError(
                f"build() got {len(arr)} values for a tree of size {n}"
            )
        self.tr = [0] * (2 * n + 2)
        self.tag = [0] * (2 * n + 2)
        self.tr[n + 1 : n + 1 + len(arr)] = arr
        for i in range(n - 1, 0, -1):
            self.tr[i] = self.tr[i << 1] + self.tr[i << 1 | 1]

    def push_up(self, u: int):
        """Recompute ``tr[u]`` from its children (no-op for the sentinel 0).

        Ignores ``tag[u]``; callers only use it on nodes whose tags were
        already pushed down.
        """
        if not u:
            return
        self.tr[u] = self.tr[u << 1] + self.tr[u << 1 | 1]

    def modify(self, u: int, val: int):
        """Lazily add ``val`` to every element under ``u`` (recorded at ``u``)."""
        self.tr[u] += self.size[u] * val
        self.tag[u] += val

    def push_down(self, u: int):
        """Hand ``u``'s pending add to both children and clear it."""
        if self.tag[u]:
            self.modify(u << 1, self.tag[u])
            self.modify(u << 1 | 1, self.tag[u])
            self.tag[u] = 0

    def query(self, left: int, right: int):
        """Return the sum of elements ``left..right`` (inclusive, 1-indexed)."""
        # Convert to the half-open leaf range [left, right).
        left += self.n
        right += self.n + 1
        # Flush pending tags from the top down along both boundary paths so
        # every node read below holds an up-to-date sum.
        for i in range(self.log[self.n] + 1, 0, -1):
            self.push_down(left >> i)
            self.push_down(right >> i)
        rst = 0
        while left < right:
            if left & 1:
                rst += self.tr[left]
                left += 1
            if right & 1:
                right -= 1
                rst += self.tr[right]
            left >>= 1
            right >>= 1
        return rst

    def update(self, left: int, right: int, val: int):
        """Add ``val`` to every element ``left..right`` (inclusive, 1-indexed)."""
        left += self.n
        right += self.n + 1
        for i in range(self.log[self.n] + 1, 0, -1):
            self.push_down(left >> i)
            self.push_down(right >> i)
        # u / v: last node modified on the left / right side (0 = none yet).
        u = v = 0
        while left < right:
            if left & 1:
                u = left
                self.modify(left, val)
                left += 1
            if right & 1:
                v = right
                right -= 1
                self.modify(right, val)
            # Re-sum the next ancestor on each side. Once the boundaries meet
            # (left == right) this is the last pass, so keep climbing until
            # the root has been refreshed.
            while True:
                u >>= 1
                self.push_up(u)
                if not (left == right and u):
                    break
            while True:
                v >>= 1
                self.push_up(v)
                if not (left == right and v):
                    break
            left >>= 1
            right >>= 1

CPP

// Children of node `u`; parenthesized so the expansion is safe inside any
// larger expression.
#define l_ch (u << 1)
#define r_ch (u << 1 | 1)
// Non-recursive (ZKW-style) segment tree with lazy range add and range sum.
// Elements are 1-indexed: element i (1 <= i <= n) is stored at leaf n + i,
// and node u has children 2u and 2u + 1. update(l, r, val) adds val to every
// element of [l, r]; query(l, r) returns their sum. Both bounds inclusive.
template <typename intType> struct ZKWTr {
    int n;
    // tr[u]: sum of u's subtree, excluding adds still pending in ancestors.
    // tag[u]: per-element add already counted in tr[u], not yet pushed down.
    std::vector<intType> tr, tag;
    std::vector<size_t> sz;  // number of leaves under each node
    int log_n;               // floor(log2(n)); 0 when n <= 1
    ZKWTr(size_t size) {
        n = size;
        // Integer floor(log2(n)). Floating log2(0) is -inf, and converting
        // that to int is undefined behavior.
        log_n = 0;
        while ((n >> log_n) > 1) log_n++;
        tr.assign(2 * n + 2, 0);
        tag.assign(2 * n + 2, 0);
        sz.assign(2 * n + 2, 1);
        // `u > 0` rather than `u`: for n == 0 the loop would start at -1.
        for (int u = n - 1; u > 0; u--) sz[u] = sz[l_ch] + sz[r_ch];
    }
    // Load vec as elements 1..vec.size() and recompute every internal sum.
    // Previous values and pending tags are cleared so the tree can be rebuilt;
    // values beyond the n-th are ignored instead of written past tr's end.
    void build(const std::vector<intType>& vec) {
        tr.assign(2 * n + 2, 0);
        tag.assign(2 * n + 2, 0);
        const size_t cnt = std::min(vec.size(), static_cast<size_t>(n));
        std::copy_n(vec.begin(), cnt, tr.begin() + n + 1);
        for (int u = n - 1; u > 0; u--) tr[u] = tr[l_ch] + tr[r_ch];
    }
    // Recompute tr[u] from its children; u == 0 is the "no node" sentinel.
    // Ignores tag[u]: only called on nodes whose tags were pushed down first.
    void push_up(int u) {
        if (!u) return;
        tr[u] = tr[l_ch] + tr[r_ch];
    }
    // Lazily add val to every element under u (recorded at u).
    void modify(int u, intType val) {
        // Convert the leaf count first so the product is done in intType; for
        // built-in signed types no wider than size_t, size_t * intType is
        // evaluated in unsigned arithmetic.
        tr[u] += static_cast<intType>(sz[u]) * val;
        tag[u] += val;
    }
    // Hand u's pending add to both children and clear it.
    void push_down(int u) {
        if (tag[u]) {
            modify(l_ch, tag[u]);
            modify(r_ch, tag[u]);
            tag[u] = 0;
        }
    }
    // Sum of elements l..r (inclusive, 1-indexed).
    intType query(int l, int r) {
        l += n, r += n + 1;  // half-open leaf range [l, r)
        // Flush pending tags top-down along both boundary paths first.
        for (int i = log_n + 1; i; i--) push_down(l >> i), push_down(r >> i);
        intType rst = 0;
        for (; l < r; l >>= 1, r >>= 1) {
            if (l & 1) rst += tr[l++];
            if (r & 1) rst += tr[--r];
        }
        return rst;
    }
    // Add val to every element l..r (inclusive, 1-indexed).
    void update(int l, int r, intType val) {
        l += n, r += n + 1;
        for (int i = log_n + 1; i; i--) push_down(l >> i), push_down(r >> i);
        // u / v: last node modified on the left / right side (0 = none yet).
        for (int u = 0, v = 0; l < r; l >>= 1, r >>= 1) {
            if (l & 1) u = l, modify(l++, val);
            if (r & 1) v = r, modify(--r, val);
            // Re-sum the next ancestor on each side; once the boundaries
            // meet (l == r) keep climbing until the root is refreshed.
            do push_up(u >>= 1);
            while (l == r and u);
            do push_up(v >>= 1);
            while (l == r and v);
        }
    }
};