Python中的树形数据结构: 线段树的区间查询

2023-04-11 00:00:00 区间 数据结构 线段

线段树(Segment Tree)是一种树形数据结构,用于高效地解决区间查询问题。它能够在 $O(\log{n})$ 时间内进行区间查询和单点修改。

在线段树中,每个节点代表一个区间,叶子节点代表单个元素。每个节点都有一个值,表示该区间的信息,如最大值、最小值、求和等等。这些值都可以通过递归计算得到。

线段树的建立过程可以用递归实现。每次递归将当前节点的区间一分为二,然后递归处理左右两个子区间,直到区间缩小到只有一个元素,即到达叶子节点。递归回溯时,将左右子节点的值合并,得到当前节点的值。

区间查询也可以用递归实现。首先判断当前节点所代表的区间是否完全包含目标区间,如果不是则分别递归左右子节点进行查询,然后将结果合并。如果当前节点所代表的区间完全包含目标区间,则直接返回当前节点的值。

单点修改也可以用递归实现。找到包含修改点的叶子节点,更新其值,然后递归向上更新祖先节点的值。

下面是一个使用线段树实现区间查询的例子,其中的元素为字符串:

class SegmentTree:
    def __init__(self, arr):
        self.arr = arr
        self.tree = [None] * (4 * len(arr))
        self._build(1, 0, len(arr) - 1)

    def _build(self, v, l, r):
        if l == r:
            self.tree[v] = self.arr[l:l+1]
        else:
            m = (l + r) // 2
            self._build(2*v, l, m)
            self._build(2*v+1, m+1, r)
            self.tree[v] = self.tree[2*v] + self.tree[2*v+1]

    def query(self, ql, qr):
        return self._query(1, 0, len(self.arr) - 1, ql, qr)

    def _query(self, v, l, r, ql, qr):
        if ql > qr:
            return ""
        if l == ql and r == qr:
            return self.tree[v]
        m = (l + r) // 2
        left = self._query(2*v, l, m, ql, min(qr, m))
        right = self._query(2*v+1, m+1, r, max(ql, m+1), qr)
        return left + right

    def update(self, i, val):
        self._update(1, 0, len(self.arr) - 1, i, val)

    def _update(self, v, l, r, i, val):
        if l == r:
            self.tree[v] = val
        else:
            m = (l + r) // 2
            if i <= m:
                self._update(2*v, l, m, i, val)
            else:
                self._update(2*v+1, m+1, r, i, val)
            self.tree[v] = self.tree[2*v] + self.tree[2*v+1]

上面的例子可以用来实现字符串的区间查询和单点修改。

示例使用:

arr = "pidancode.com"
st = SegmentTree(arr)

# 查询区间[0, 3]
print(st.query(0, 3))  # "pida"

# 修改第一个字符为"q"
st.update(0, "q")
print(st.query(0, 3))  # "qida"

注意,上面的查询和修改操作中,左右端点均为闭区间。如果需要修改为半开区间,在处理时需要做相应的调整。同时需要注意空字符串的处理。

相关文章