- Classic algorithm: Segment Tree
- 1. Introduction to Segment Tree
- 2. Segment Tree algorithm implementation
- 1. Principle explanation
- 2. Vanilla code implementation
- 3. Optimized design (1)
- 4. Optimized design (2)
- 3. Example walkthrough
- 1. Leetcode 2659
- 1. Problem-solving ideas
- 2. Code implementation
- 4. Reference link
1. Introduction to Segment Tree
The segment tree is a classic data structure in programming. Its most common use is querying a feature of an interval, such as its maximum, minimum, or sum, in an array whose length is fixed but whose values are updated frequently.
Specifically, for an array of fixed length, we want to update individual values and then query a feature of the values within some range, such as their sum, maximum, or minimum.
Take summation as an example. Updating a single value or summing a range can each be done in O(1) time on its own, but the two are not compatible. If we update values in O(1) time, summing a range becomes an O(N) operation; conversely, if we keep a cumulative (prefix-sum) array, the sum of a range can be found in O(1) time, but every update must also refresh the cumulative array, which becomes an O(N) operation.
Therefore, if updates and range queries are equally frequent, neither approach is acceptable in terms of time complexity.
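To make the trade-off concrete, here is a minimal sketch of the two naive approaches for range sums. The class names `PlainArray` and `PrefixSumArray` are illustrative only and do not come from the original article.

```python
class PlainArray:
    """O(1) point update, O(N) range sum."""

    def __init__(self, arr):
        self.vals = list(arr)

    def update(self, idx, val):
        self.vals[idx] = val  # constant time

    def query(self, lb, rb):
        # Inclusive range [lb, rb]; scans every element in the range.
        return sum(self.vals[lb:rb + 1])


class PrefixSumArray:
    """O(N) point update, O(1) range sum."""

    def __init__(self, arr):
        self.vals = list(arr)
        self.prefix = [0]  # prefix[i] = sum of the first i elements
        for x in self.vals:
            self.prefix.append(self.prefix[-1] + x)

    def update(self, idx, val):
        delta = val - self.vals[idx]
        self.vals[idx] = val
        # Every prefix sum that covers idx has to be shifted.
        for i in range(idx + 1, len(self.prefix)):
            self.prefix[i] += delta

    def query(self, lb, rb):
        # Inclusive range [lb, rb]; a single subtraction.
        return self.prefix[rb + 1] - self.prefix[lb]
```

Both classes return the same answers; they differ only in which of the two operations pays the linear cost.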
One of the core ideas of the segment tree is to borrow the structure of a binary tree and store each segment of the array in a tree node. Both updating a value and querying a range can then be done in O(log N) time, which improves the overall efficiency.
The following figure is a typical example, found online, of a segment tree for range-minimum queries.
[Figure: example range-minimum segment tree (./imgs/segmentTree_fig01.png)]
Next, let's look at the concrete implementation of the segment tree, that is, how it achieves both value updates and range queries in O(log N) time.
2. Segment Tree algorithm implementation
1. Principle explanation
As shown in the figure above, a segment tree is a binary tree in which each node stores a feature (such as the minimum) of the elements in a certain interval. The left and right children of a node cover the two halves of the interval represented by their parent, and each leaf corresponds to a single element of the array.
Locating the leaf for a given position in the array therefore amounts to a binary search down the tree, so reaching a particular value takes O(log N) time.
However, changing a value also affects the feature of every interval that contains it, so those nodes must be updated as well. We walk from the leaf back up to the root and recompute each node on the way, which is again an O(log N) operation.
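The path touched by a single update can be sketched as follows. `update_path` is a hypothetical helper, not part of the implementations in this article; it only records the intervals of the nodes between the root and the target leaf, which are the nodes whose values have to be recomputed.

```python
def update_path(n, idx):
    """Intervals of the nodes visited when updating position idx, root first.

    Assumes an array of length n whose root covers [0, n - 1] and whose
    intervals are split at mid = (lb + rb) // 2.
    """
    path = []
    lb, rb = 0, n - 1
    while True:
        path.append((lb, rb))
        if lb == rb:  # reached the leaf
            return path
        mid = (lb + rb) // 2
        if idx <= mid:
            rb = mid      # descend into the left child
        else:
            lb = mid + 1  # descend into the right child


print(update_path(8, 5))  # [(0, 7), (4, 7), (4, 5), (5, 5)]
```

For an array of 8 elements only 4 nodes are touched, one per level of the tree.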
Finally, consider how to obtain a feature over a range. Any range can be split into a combination of nodes of the segment tree, so we only need to find these nodes recursively and combine their values to get the answer.
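As a small illustration of this decomposition, the sketch below lists the node intervals that make up a query range. The function name `decompose` is an assumption made for this example; it follows the same midpoint split as the update path and returns intervals instead of values.

```python
def decompose(lb, rb, left, right):
    """Split the query range [left, right] into node intervals of the
    segment tree whose current node covers [lb, rb] (all bounds inclusive)."""
    if left == lb and right == rb:
        return [(lb, rb)]  # the node matches the query exactly
    mid = (lb + rb) // 2
    if right <= mid:
        return decompose(lb, mid, left, right)      # entirely in the left child
    if left > mid:
        return decompose(mid + 1, rb, left, right)  # entirely in the right child
    # The query straddles the midpoint: split it and recurse on both sides.
    return (decompose(lb, mid, left, mid)
            + decompose(mid + 1, rb, mid + 1, right))


print(decompose(0, 7, 2, 6))  # [(2, 3), (4, 5), (6, 6)]
```

The five elements of the range [2, 6] are covered by just three nodes, and the answer is the combination of their stored values.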
2. Vanilla code implementation
A Python implementation of the segment tree is given below:
    class SegmentTreeNode:
        """One node of the tree: the feature of arr[lbound..rbound]."""

        def __init__(self, val, lbound, rbound, lchild=None, rchild=None):
            self.val = val
            self.lbound = lbound
            self.rbound = rbound
            self.lchild = lchild
            self.rchild = rchild


    class SegmentTree:
        def __init__(self, arr):
            self.length = len(arr)
            self.root = self.build(0, self.length - 1, arr)
            self.vals = arr

        def feature_func(self, lval, rval):
            # Get the target feature, such as sum, min or max.
            raise NotImplementedError()

        def build(self, lbound, rbound, arr):
            if lbound == rbound:
                # Leaf: a single element of the array.
                root = SegmentTreeNode(arr[lbound], lbound, rbound)
            else:
                mid = (lbound + rbound) // 2
                lchild = self.build(lbound, mid, arr)
                rchild = self.build(mid + 1, rbound, arr)
                val = self.feature_func(lchild.val, rchild.val)
                root = SegmentTreeNode(val, lbound, rbound, lchild, rchild)
            return root

        def update(self, idx, val):
            self.vals[idx] = val
            self._update(idx, val, self.root)

        def _update(self, idx, val, root):
            if root.lbound == root.rbound:
                assert root.lbound == idx
                root.val = val
                return
            mid = (root.lbound + root.rbound) // 2
            if idx <= mid:
                self._update(idx, val, root.lchild)
            else:
                self._update(idx, val, root.rchild)
            # Recompute this node after the child has changed.
            root.val = self.feature_func(root.lchild.val, root.rchild.val)

        def query(self, lb, rb):
            # Inclusive range [lb, rb].
            return self._query(lb, rb, self.root)

        def _query(self, lb, rb, root):
            if lb == root.lbound and rb == root.rbound:
                return root.val
            mid = (root.lbound + root.rbound) // 2
            if rb <= mid:
                return self._query(lb, rb, root.lchild)
            elif lb > mid:
                return self._query(lb, rb, root.rchild)
            else:
                lval = self._query(lb, mid, root.lchild)
                rval = self._query(mid + 1, rb, root.rchild)
                return self.feature_func(lval, rval)
For different tasks, we only need to change feature_func accordingly.
Some typical cases are as follows:
- Find the maximum value in the range
    def feature_func(self, lval, rval):
        return max(lval, rval)
- Find the minimum value in the range
    def feature_func(self, lval, rval):
        return min(lval, rval)
- Find the sum of elements in a range
    def feature_func(self, lval, rval):
        return lval + rval
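What these three functions have in common is that the feature of a range can be assembled from the features of its pieces. The short check below illustrates this with plain callables; `split_and_combine` is a hypothetical helper written for this example, not part of the class above.

```python
from functools import reduce


def split_and_combine(arr, combine, mid):
    """Aggregate arr[:mid] and arr[mid:] separately, then merge the two results.

    Assumes 0 < mid < len(arr) so that neither half is empty.
    """
    left = reduce(combine, arr[:mid])
    right = reduce(combine, arr[mid:])
    return combine(left, right)


arr = [5, 2, 7, 1, 9, 3]
for name, combine in [("max", max), ("min", min), ("sum", lambda a, b: a + b)]:
    # Merging the two halves agrees with aggregating the whole array at once.
    assert split_and_combine(arr, combine, 3) == reduce(combine, arr)
    print(name, split_and_combine(arr, combine, 3))
```

Any combining function with this property can be plugged in as feature_func in the same way.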
3. Optimized design (1)
Since any binary tree can be stored in an array, the implementation above can be optimized by replacing the node objects with a flat list:
    from copy import deepcopy


    class SegmentTree:
        def __init__(self, arr):
            self.length = len(arr)
            # Node i has children 2*i and 2*i + 1; the root is node 1.
            self.tree = [0 for _ in range(4 * self.length)]
            self.vals = deepcopy(arr)
            self.build(1, arr, 0, self.length - 1)

        def feature_func(self, lval, rval):
            return lval + rval

        def build(self, node, arr, lb, rb):
            if lb == rb:
                self.tree[node] = arr[lb]
            else:
                mid = (lb + rb) // 2
                lval = self.build(2 * node, arr, lb, mid)
                rval = self.build(2 * node + 1, arr, mid + 1, rb)
                self.tree[node] = self.feature_func(lval, rval)
            return self.tree[node]

        def _update(self, idx, val, node, lb, rb):
            if lb == rb:
                assert lb == idx
                self.tree[node] = val
            else:
                mid = (lb + rb) // 2
                if idx <= mid:
                    self._update(idx, val, 2 * node, lb, mid)
                else:
                    self._update(idx, val, 2 * node + 1, mid + 1, rb)
                self.tree[node] = self.feature_func(self.tree[2 * node], self.tree[2 * node + 1])

        def update(self, idx, val):
            self.vals[idx] = val
            self._update(idx, val, 1, 0, self.length - 1)

        def _query(self, left, right, node, lb, rb):
            if left == lb and right == rb:
                return self.tree[node]
            mid = (lb + rb) // 2
            if right <= mid:
                return self._query(left, right, 2 * node, lb, mid)
            elif left > mid:
                return self._query(left, right, 2 * node + 1, mid + 1, rb)
            else:
                lval = self._query(left, mid, 2 * node, lb, mid)
                rval = self._query(mid + 1, right, 2 * node + 1, mid + 1, rb)
                return self.feature_func(lval, rval)

        def query(self, lb, rb):
            # Inclusive range [lb, rb].
            return self._query(lb, rb, 1, 0, self.length - 1)
Similarly, some typical segment tree feature functions are given as follows:
- Find the maximum value in the range
    def feature_func(self, lval, rval):
        return max(lval, rval)
- Find the minimum value in the range
    def feature_func(self, lval, rval):
        return min(lval, rval)
- Find the sum of elements in a range
    def feature_func(self, lval, rval):
        return lval + rval
Note, however, that although in principle a binary tree with n leaf nodes needs only 2n-1 nodes, the tree here is not always a complete binary tree, so some redundant slots are needed to guarantee that every node index fits in the array. In practice we allocate up to 4n slots to store the tree nodes.
This costs some performance and wastes space, so the code above can be optimized further.
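The gap between 2n-1 and 4n can be checked directly. The sketch below replays the recursive index assignment (children of node i at 2i and 2i+1, split at the midpoint) and reports the largest index in use; `max_index` is a hypothetical helper introduced for this check.

```python
def max_index(n):
    """Largest array index used when a tree with n leaves is built recursively
    with the root at index 1 and children at 2*node and 2*node + 1."""
    def visit(node, lb, rb):
        if lb == rb:
            return node
        mid = (lb + rb) // 2
        return max(visit(2 * node, lb, mid),
                   visit(2 * node + 1, mid + 1, rb))
    return visit(1, 0, n - 1)


print(max_index(6))  # 13, already larger than 2 * 6 - 1 = 11

# The largest index never reaches 4 * n for the sizes tried here.
assert all(max_index(n) < 4 * n for n in range(1, 300))
```

So an array of 2n slots is not always enough for this layout, while 4n is a safe upper bound.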
4. Optimized design (2)
The optimized Python implementation is given below:
    class SegmentTree:
        def __init__(self, arr):
            self.length = len(arr)
            self.tree = self.build(arr)

        def feature_func(self, *args):
            # Get the target feature, such as sum, min or max.
            raise NotImplementedError()

        def build(self, arr):
            n = len(arr)
            # Leaves live at tree[n..2n-1]; node i has children 2*i and 2*i + 1.
            tree = [0 for _ in range(2 * n)]
            for i in range(n):
                tree[i + n] = arr[i]
            for i in range(n - 1, 0, -1):
                tree[i] = self.feature_func(tree[2 * i], tree[2 * i + 1])
            return tree

        def update(self, idx, val):
            idx = idx + self.length
            self.tree[idx] = val
            while idx > 1:
                # idx ^ 1 is the sibling of idx.
                self.tree[idx // 2] = self.feature_func(self.tree[idx], self.tree[idx ^ 1])
                idx = idx // 2

        def query(self, lb, rb):
            # Inclusive range [lb, rb], walked bottom-up from the leaves.
            lb += self.length
            rb += self.length
            nodes = []
            while lb < rb:
                if lb % 2 == 1:
                    nodes.append(self.tree[lb])
                    lb += 1
                if rb % 2 == 0:
                    nodes.append(self.tree[rb])
                    rb -= 1
                lb = lb // 2
                rb = rb // 2
            if lb == rb:
                nodes.append(self.tree[rb])
            return self.feature_func(*nodes)
Similarly, some typical segment tree feature functions are given as follows:
- Find the maximum value in the range
    def feature_func(self, *args):
        return max(args)
- Find the minimum value in the range
    def feature_func(self, *args):
        return min(args)
- Find the sum of elements in a range
    def feature_func(self, *args):
        return sum(args)
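To see which nodes the bottom-up query collects, the sketch below replays the same loop but records tree indices instead of values. `visited_nodes` is a hypothetical helper for illustration; it assumes the 2n layout in which leaf i is stored at index n + i.

```python
def visited_nodes(n, lb, rb):
    """Tree indices collected by the bottom-up query for the inclusive range [lb, rb]."""
    lb += n
    rb += n
    nodes = []
    while lb < rb:
        if lb % 2 == 1:  # lb is a right child: take it and step right
            nodes.append(lb)
            lb += 1
        if rb % 2 == 0:  # rb is a left child: take it and step left
            nodes.append(rb)
            rb -= 1
        lb //= 2
        rb //= 2
    if lb == rb:
        nodes.append(lb)
    return nodes


print(visited_nodes(8, 2, 6))  # [14, 5, 6]
```

With n = 8, node 5 covers positions 2 and 3, node 6 covers positions 4 and 5, and node 14 is the leaf for position 6, so the three nodes together cover exactly the range [2, 6].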
The more common implementation found online uses bit operations instead:
    class SegmentTree:
        def __init__(self, arr):
            self.length = len(arr)
            self.tree = self.build(arr)

        def feature_func(self, *args):
            # Get the target feature, such as sum, min or max.
            raise NotImplementedError()

        def build(self, arr):
            n = len(arr)
            tree = [0 for _ in range(2 * n)]
            for i in range(n):
                tree[i + n] = arr[i]
            for i in range(n - 1, 0, -1):
                tree[i] = self.feature_func(tree[i << 1], tree[(i << 1) | 1])
            return tree

        def update(self, idx, val):
            idx = idx + self.length
            self.tree[idx] = val
            while idx > 1:
                self.tree[idx >> 1] = self.feature_func(self.tree[idx], self.tree[idx ^ 1])
                idx = idx >> 1

        def query(self, lb, rb):
            # Inclusive range [lb, rb].
            lb += self.length
            rb += self.length
            nodes = []
            while lb < rb:
                if (lb & 1) == 1:
                    nodes.append(self.tree[lb])
                    lb += 1
                if (rb & 1) == 0:
                    nodes.append(self.tree[rb])
                    rb -= 1
                lb = lb >> 1
                rb = rb >> 1
            if lb == rb:
                nodes.append(self.tree[rb])
            return self.feature_func(*nodes)
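The bit operations are drop-in replacements for the arithmetic used in the previous version. A quick check of the correspondences, with `check_bit_identities` being a helper name invented for this example:

```python
def check_bit_identities(limit):
    """Verify that the bit tricks match the arithmetic forms for 1 <= i < limit."""
    for i in range(1, limit):
        assert i << 1 == 2 * i                 # left child
        assert (i << 1) | 1 == 2 * i + 1       # right child
        assert i >> 1 == i // 2                # parent
        assert (i ^ 1) >> 1 == i >> 1          # i ^ 1 shares the parent: the sibling
        assert ((i & 1) == 1) == (i % 2 == 1)  # parity test
    return True


print(check_bit_identities(1024))  # True
```

The two versions therefore behave identically; the choice between them is a matter of style.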
3. Example walkthrough
1. Leetcode 2659
Problem link:
- 2659. Make Array Empty
1. Problem-solving ideas
The idea behind this problem is fairly straightforward. Elements are always removed in increasing order of value, so we sort the indices by value; for each removal we then know how far we have to move from the position of the previous removal. The actual number of moves is the number of elements still remaining between the two indices (counting the target itself), minus one. Each removal costs one more operation, which is why n is added at the end.
Therefore, we only need to use a segment tree for range sum processing.
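For checking results on small inputs, the process can also be simulated literally. The sketch below is a brute-force reference, not the intended solution; it assumes the rules of the problem (remove the first element if it is the smallest, otherwise move it to the end) and distinct values, and the function name is made up for this example.

```python
from collections import deque


def count_operations_brute_force(nums):
    """Simulate the process step by step; quadratic in the worst case."""
    queue = deque(nums)
    ops = 0
    for target in sorted(nums):  # values are removed in increasing order
        while queue[0] != target:
            queue.append(queue.popleft())  # move the first element to the end
            ops += 1
        queue.popleft()  # the first element is the smallest: remove it
        ops += 1
    return ops


print(count_operations_brute_force([3, 4, -1]))  # 5
```

The segment tree replaces the inner loop: instead of rotating one element at a time, it counts the remaining elements between two positions in O(log N) time.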
2. Code implementation
The Python implementation is given below:
    from typing import List


    class SegmentTree:
        def __init__(self, arr):
            self.length = len(arr)
            self.tree = self.build(arr)

        def feature_func(self, *args):
            return sum(args)

        def build(self, arr):
            n = len(arr)
            tree = [0 for _ in range(2 * n)]
            for i in range(n):
                tree[i + n] = arr[i]
            for i in range(n - 1, 0, -1):
                tree[i] = self.feature_func(tree[i << 1], tree[(i << 1) | 1])
            return tree

        def update(self, idx, val):
            idx = idx + self.length
            self.tree[idx] = val
            while idx > 1:
                self.tree[idx >> 1] = self.feature_func(self.tree[idx], self.tree[idx ^ 1])
                idx = idx >> 1

        def query(self, lb, rb):
            lb += self.length
            rb += self.length
            nodes = []
            while lb < rb:
                if (lb & 1) == 1:
                    nodes.append(self.tree[lb])
                    lb += 1
                if (rb & 1) == 0:
                    nodes.append(self.tree[rb])
                    rb -= 1
                lb = lb >> 1
                rb = rb >> 1
            if lb == rb:
                nodes.append(self.tree[rb])
            return self.feature_func(*nodes)


    class Solution:
        def countOperationsToEmptyArray(self, nums: List[int]) -> int:
            n = len(nums)
            # Indices in the order in which their values are removed.
            index = [i for i in range(n)]
            index = sorted(index, key=lambda x: nums[x])
            # status[i] is 1 while nums[i] is still in the array, 0 once removed.
            status = [1 for _ in range(n)]
            segment_tree = SegmentTree(status)
            prev = 0
            res = 0
            for idx in index:
                if idx >= prev:
                    res += segment_tree.query(prev, idx) - 1
                else:
                    # Wrap around the end of the array.
                    res += segment_tree.query(0, idx) + segment_tree.query(prev, n - 1) - 1
                prev = idx
                segment_tree.update(idx, 0)
            # One extra operation per element for the removal itself.
            return res + n
Submitting the code for evaluation gives a running time of 6159 ms and a memory usage of 31.2 MB.
4. Reference link
- https://www.hackerearth.com/practice/data-structures/advanced-data-structures/segment-trees/tutorial/
- https://www.geeksforgeeks.org/segment-tree-sum-of-given-range/
- https://www.geeksforgeeks.org/segment-tree-efficient-implementation/