Classic algorithm: Segment Tree

  • Classic algorithm: Segment Tree
    • 1. Introduction to Segment Tree
    • 2. Segment Tree algorithm implementation
      • 1. Principle explanation
      • 2. Vanilla code implementation
      • 3. Optimized design (1)
      • 4. Optimized design (2)
    • 3. Worked example
      • 1. LeetCode 2659
        • 1. Problem-solving ideas
        • 2. Code implementation
    • 4. References

1. Introduction to Segment Tree

The segment tree is another classic algorithm. Its most common use is querying a feature of a range in an array whose length is fixed but whose values are updated frequently, such as the range maximum, minimum, or sum.

Specifically, for a fixed-length array, we want to both update individual values and query a feature of any range, such as its sum, maximum, or minimum.

Take range sums as an example. Updating a single value and summing a range can each be done in $O(1)$ time on their own, but not both at once. If we store the array as-is, a value can be updated in $O(1)$, but summing a range becomes an $O(N)$ operation. Conversely, if we keep a cumulative (prefix-sum) array, any range sum can be found in $O(1)$, but every update must also be propagated through the cumulative array, which becomes an $O(N)$ operation.
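To make the trade-off concrete, here is a minimal sketch of the cumulative-array approach (the class name PrefixSumArray is hypothetical). A query is a single subtraction, but an update has to shift every later prefix entry.

```python
from typing import List


class PrefixSumArray:
    """Range sums via a cumulative array: O(1) query, O(N) update."""

    def __init__(self, arr: List[int]) -> None:
        self.vals = list(arr)
        # prefix[i] holds sum(vals[0:i]); prefix[0] == 0
        self.prefix = [0] * (len(arr) + 1)
        for i, v in enumerate(arr):
            self.prefix[i + 1] = self.prefix[i] + v

    def query(self, lb: int, rb: int) -> int:
        # sum of vals[lb..rb], both ends inclusive: one subtraction
        return self.prefix[rb + 1] - self.prefix[lb]

    def update(self, idx: int, val: int) -> None:
        # every prefix entry after idx shifts by the same delta: O(N)
        delta = val - self.vals[idx]
        self.vals[idx] = val
        for i in range(idx + 1, len(self.prefix)):
            self.prefix[i] += delta


if __name__ == "__main__":
    ps = PrefixSumArray([1, 2, 3, 4])
    print(ps.query(1, 2))
    ps.update(2, 10)
    print(ps.query(1, 2))
```

With [1, 2, 3, 4], query(1, 2) returns 5; after update(2, 10) it returns 12, at the cost of rewriting every prefix entry from position 3 onward.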

Therefore, when updates and range queries are both frequent, neither approach has acceptable time complexity.

One of the core ideas of the segment tree is to borrow the structure of a binary tree and store each segment of the array in one node of the tree. Both a value update and a range query can then be done in $O(\log N)$ time, which improves the overall efficiency.

The figure below, found online, shows a typical segment tree for range-minimum queries.

[Figure: example range-minimum segment tree (./imgs/segmentTree_fig01.png)]
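To make the structure concrete, here is a small sketch that builds a range-minimum segment tree for a sample array and groups its nodes by depth (min_tree_levels is a hypothetical helper, not part of the implementation that follows). Each entry is (lbound, rbound, minimum of that interval).

```python
from typing import Dict, List, Tuple

Entry = Tuple[int, int, int]  # (lbound, rbound, min of arr[lbound..rbound])


def min_tree_levels(arr: List[int]) -> Dict[int, List[Entry]]:
    """Build a range-minimum segment tree and group its nodes by depth."""
    levels: Dict[int, List[Entry]] = {}

    def build(lb: int, rb: int, depth: int) -> int:
        if lb == rb:
            val = arr[lb]  # leaf: a single array element
        else:
            mid = (lb + rb) // 2  # children bisect the parent's interval
            val = min(build(lb, mid, depth + 1), build(mid + 1, rb, depth + 1))
        # post-order visit keeps each depth's nodes in left-to-right order
        levels.setdefault(depth, []).append((lb, rb, val))
        return val

    build(0, len(arr) - 1, 0)
    return levels


if __name__ == "__main__":
    levels = min_tree_levels([5, 2, 8, 1, 9, 3])
    for depth in sorted(levels):
        print(depth, levels[depth])
```

For [5, 2, 8, 1, 9, 3], the root covers [0, 5] with minimum 1, its children cover [0, 2] (minimum 2) and [3, 5] (minimum 1), and so on down to the six leaves.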

Next, let's look at how a segment tree is implemented, that is, how it achieves both value updates and range-feature queries in $O(\log N)$ time.

2. Segment Tree algorithm implementation

1. Principle explanation

As shown in the figure above, a segment tree is a binary tree in which each node stores the feature (sum, minimum, maximum, and so on) of the elements in some interval. The left and right children of a node split its interval into two halves, and each leaf holds a single element of the array.

Therefore, locating the leaf for a given position in the array is a binary search from the root, so retrieving a value takes $O(\log N)$ time.

However, changing a value also changes the feature of every interval that contains it, so those nodes must be updated too. We can walk from the leaf back up to the root, recomputing each node from its two children along the way, which is again an $O(\log N)$ operation.

Finally, consider how to obtain the feature of a range. Any range can be split into a union of nodes of the segment tree, and only $O(\log N)$ of them are needed, so we just find these nodes and combine their values to get the answer.
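Both operations can be traced with a small sketch that works only on interval bounds, using the same midpoint split as the vanilla implementation below (update_path and decompose are hypothetical helper names). For an array of length 8, updating index 5 visits 4 nodes from root to leaf, and the range [1, 6] splits into the four nodes [1, 1], [2, 3], [4, 5] and [6, 6].

```python
from typing import List, Tuple

Interval = Tuple[int, int]


def update_path(lbound: int, rbound: int, idx: int) -> List[Interval]:
    """Intervals visited when descending from the root [lbound, rbound] to leaf idx."""
    path = [(lbound, rbound)]
    while lbound != rbound:
        mid = (lbound + rbound) // 2
        if idx <= mid:
            rbound = mid  # go left
        else:
            lbound = mid + 1  # go right
        path.append((lbound, rbound))
    return path  # an update recomputes these nodes in reverse order


def decompose(lbound: int, rbound: int, lb: int, rb: int) -> List[Interval]:
    """Tree nodes whose intervals exactly tile the query range [lb, rb]."""
    if lb == lbound and rb == rbound:
        return [(lbound, rbound)]  # node fully covered: use it as-is
    mid = (lbound + rbound) // 2
    if rb <= mid:
        return decompose(lbound, mid, lb, rb)
    if lb > mid:
        return decompose(mid + 1, rbound, lb, rb)
    # range straddles mid: split it at mid
    return decompose(lbound, mid, lb, mid) + decompose(mid + 1, rbound, mid + 1, rb)


if __name__ == "__main__":
    print(update_path(0, 7, 5))
    print(decompose(0, 7, 1, 6))
```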

2. Vanilla code implementation

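A Python implementation of the segment tree is given below. SegmentTree is meant to be used as a base class: feature_func must be overridden in a subclass before a tree can be built.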

class SegmentTreeNode:
    """A node covering arr[lbound..rbound], both ends inclusive."""
    def __init__(self, val, lbound, rbound, lchild=None, rchild=None):
        self.val = val          # feature of arr[lbound..rbound]
        self.lbound = lbound
        self.rbound = rbound
        self.lchild = lchild
        self.rchild = rchild

class SegmentTree:
    def __init__(self, arr):
        self.length = len(arr)
        self.vals = list(arr)
        self.root = self.build(0, self.length - 1, arr)

    def feature_func(self, lval, rval):
        # combine two child features, such as sum, min or max; override in a subclass
        raise NotImplementedError()

    def build(self, lbound, rbound, arr):
        if lbound == rbound:
            # leaf node: a single element of the array
            root = SegmentTreeNode(arr[lbound], lbound, rbound)
        else:
            # split the interval in half and build both subtrees
            mid = (lbound + rbound) // 2
            lchild = self.build(lbound, mid, arr)
            rchild = self.build(mid + 1, rbound, arr)
            val = self.feature_func(lchild.val, rchild.val)
            root = SegmentTreeNode(val, lbound, rbound, lchild, rchild)
        return root

    def update(self, idx, val):
        self.vals[idx] = val
        self._update(idx, val, self.root)

    def _update(self, idx, val, root):
        if root.lbound == root.rbound:
            assert root.lbound == idx
            root.val = val
            return
        # descend into the half that contains idx
        mid = (root.lbound + root.rbound) // 2
        if idx <= mid:
            self._update(idx, val, root.lchild)
        else:
            self._update(idx, val, root.rchild)
        # recompute this node from its children on the way back up
        root.val = self.feature_func(root.lchild.val, root.rchild.val)

    def query(self, lb, rb):
        # feature of arr[lb..rb], both ends inclusive
        return self._query(lb, rb, self.root)

    def _query(self, lb, rb, root):
        if lb == root.lbound and rb == root.rbound:
            return root.val
        mid = (root.lbound + root.rbound) // 2
        if rb <= mid:
            # range lies entirely in the left half
            return self._query(lb, rb, root.lchild)
        elif lb > mid:
            # range lies entirely in the right half
            return self._query(lb, rb, root.rchild)
        else:
            # range straddles mid: query both halves and combine
            lval = self._query(lb, mid, root.lchild)
            rval = self._query(mid + 1, rb, root.rchild)
            return self.feature_func(lval, rval)

For different tasks, we only need to override feature_func accordingly.

Some typical cases are as follows:

  1. Find the maximum value in the range

    def feature_func(self, lval, rval):
        return max(lval, rval)
    
  2. Find the minimum value in the range

    def feature_func(self, lval, rval):
        return min(lval, rval)
    
  3. Find the sum of elements in a range

    def feature_func(self, lval, rval):
        return lval + rval
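Instead of writing one subclass per feature, a small variation is to pass the feature function in as an argument. The sketch below (FuncSegmentTree and Node are hypothetical names, not part of the code above) keeps the same recursive logic and works with max, min or operator.add unchanged.

```python
import operator
from typing import Callable, List, Optional


class Node:
    """A node covering arr[lbound..rbound], both ends inclusive."""

    def __init__(self, lbound: int, rbound: int) -> None:
        self.lbound = lbound
        self.rbound = rbound
        self.val = 0
        self.lchild: Optional["Node"] = None
        self.rchild: Optional["Node"] = None


class FuncSegmentTree:
    """Node-based segment tree that receives its feature function as an argument."""

    def __init__(self, arr: List[int], feature_func: Callable[[int, int], int]) -> None:
        self.feature_func = feature_func
        self.root = self._build(arr, 0, len(arr) - 1)

    def _build(self, arr: List[int], lb: int, rb: int) -> Node:
        node = Node(lb, rb)
        if lb == rb:
            node.val = arr[lb]
        else:
            mid = (lb + rb) // 2
            node.lchild = self._build(arr, lb, mid)
            node.rchild = self._build(arr, mid + 1, rb)
            node.val = self.feature_func(node.lchild.val, node.rchild.val)
        return node

    def update(self, idx: int, val: int) -> None:
        self._update(self.root, idx, val)

    def _update(self, node: Node, idx: int, val: int) -> None:
        if node.lbound == node.rbound:
            node.val = val
            return
        mid = (node.lbound + node.rbound) // 2
        self._update(node.lchild if idx <= mid else node.rchild, idx, val)
        node.val = self.feature_func(node.lchild.val, node.rchild.val)

    def query(self, lb: int, rb: int) -> int:
        return self._query(self.root, lb, rb)

    def _query(self, node: Node, lb: int, rb: int) -> int:
        if lb == node.lbound and rb == node.rbound:
            return node.val
        mid = (node.lbound + node.rbound) // 2
        if rb <= mid:
            return self._query(node.lchild, lb, rb)
        if lb > mid:
            return self._query(node.rchild, lb, rb)
        return self.feature_func(self._query(node.lchild, lb, mid),
                                 self._query(node.rchild, mid + 1, rb))


if __name__ == "__main__":
    arr = [5, 2, 8, 1, 9, 3]
    print(FuncSegmentTree(arr, max).query(1, 4))           # range maximum
    print(FuncSegmentTree(arr, min).query(1, 4))           # range minimum
    print(FuncSegmentTree(arr, operator.add).query(1, 4))  # range sum
```

With arr = [5, 2, 8, 1, 9, 3], querying [1, 4] gives 9, 1 and 20 respectively. Whichever style is used, the feature must be associative, since a query combines partial results in whatever grouping the tree happens to produce.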
    

3. Optimized design (1)

On the other hand, any binary tree can be stored in an array: put the root at index 1 and the children of node i at indices 2i and 2i + 1. Using this layout, we can replace the node objects in the code above with a flat list:

from copy import deepcopy

class SegmentTree:
    def __init__(self, arr):
        self.length = len(arr)
        # root is node 1; node i has children 2i and 2i + 1
        self.tree = [0 for _ in range(4 * self.length)]
        self.vals = deepcopy(arr)
        self.build(1, arr, 0, self.length - 1)

    def feature_func(self, lval, rval):
        return lval + rval

    def build(self, node, arr, lb, rb):
        # fill node, which covers arr[lb..rb], and return its feature
        if lb == rb:
            self.tree[node] = arr[lb]
        else:
            mid = (lb + rb) // 2
            lval = self.build(2*node, arr, lb, mid)
            rval = self.build(2*node + 1, arr, mid + 1, rb)
            self.tree[node] = self.feature_func(lval, rval)
        return self.tree[node]

    def _update(self, idx, val, node, lb, rb):
        if lb == rb:
            assert lb == idx
            self.tree[node] = val
        else:
            mid = (lb + rb) // 2
            if idx <= mid:
                self._update(idx, val, 2*node, lb, mid)
            else:
                self._update(idx, val, 2*node + 1, mid + 1, rb)
            # recompute node from its two children
            self.tree[node] = self.feature_func(self.tree[2*node], self.tree[2*node + 1])

    def update(self, idx, val):
        self.vals[idx] = val
        self._update(idx, val, 1, 0, self.length - 1)

    def _query(self, left, right, node, lb, rb):
        # feature of arr[left..right], where node covers arr[lb..rb]
        if left == lb and right == rb:
            return self.tree[node]
        mid = (lb + rb) // 2
        if right <= mid:
            return self._query(left, right, 2*node, lb, mid)
        elif left > mid:
            return self._query(left, right, 2*node + 1, mid + 1, rb)
        else:
            lval = self._query(left, mid, 2*node, lb, mid)
            rval = self._query(mid + 1, right, 2*node + 1, mid + 1, rb)
            return self.feature_func(lval, rval)

    def query(self, lb, rb):
        return self._query(lb, rb, 1, 0, self.length - 1)

Similarly, some typical segment tree feature functions are given as follows:

  1. Find the maximum value in the range

    def feature_func(self, lval, rval):
        return max(lval, rval)
    
  2. Find the minimum value in the range

    def feature_func(self, lval, rval):
        return min(lval, rval)
    
  3. Find the sum of elements in a range

    def feature_func(self, lval, rval):
        return lval + rval
    

However, note that although in principle a binary tree with $n$ leaves has only $2n - 1$ nodes, the tree here is not always a complete binary tree, so with this indexing scheme some array slots are left unused. To make sure every node has a slot, we need up to $4n$ slots in total.
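The bound can be checked with a small sketch (the helper max_index is hypothetical): it mirrors the recursive build above and returns the largest slot that gets written. The tree has depth at most $\lceil \log_2 n \rceil$, so every slot index is below $2^{\lceil \log_2 n \rceil + 1}$, which is less than $4n$.

```python
def max_index(node: int, lb: int, rb: int) -> int:
    """Largest slot written when `node` covers arr[lb..rb] (the root is node 1)."""
    if lb == rb:
        # a child's index is always larger than its parent's, so the maximum is at a leaf
        return node
    mid = (lb + rb) // 2
    return max(max_index(2 * node, lb, mid),
               max_index(2 * node + 1, mid + 1, rb))


if __name__ == "__main__":
    for n in (4, 5, 6, 8, 100):
        used = max_index(1, 0, n - 1) + 1  # list length actually needed
        print(f"n={n}: needs {used} slots, 2n-1={2 * n - 1}, 4n={4 * n}")
```

For n = 6, for example, the largest slot is 13, so 14 list entries are needed even though the tree has only 2n - 1 = 11 nodes.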

This wastes some space and costs some performance, so the code can be optimized further.

4. Optimized design (2)

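The optimized Python implementation is given below. It stores the $n$ leaves at indices $n$ to $2n - 1$ and fills in the internal nodes bottom-up, so only $2n$ slots are needed: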

class SegmentTree:
    def __init__(self, arr):
        self.length = len(arr)
        self.tree = self.build(arr)

    def feature_func(self, *args):
        # combine any number of features, such as sum, min or max; override in a subclass
        raise NotImplementedError()

    def build(self, arr):
        n = len(arr)
        tree = [0 for _ in range(2*n)]
        # leaves occupy tree[n..2n-1]
        for i in range(n):
            tree[i + n] = arr[i]
        # internal node i combines its children 2i and 2i + 1 (tree[0] is unused)
        for i in range(n-1, 0, -1):
            tree[i] = self.feature_func(tree[2*i], tree[2*i + 1])
        return tree

    def update(self, idx, val):
        idx = idx + self.length
        self.tree[idx] = val
        # walk up to the root; idx ^ 1 is the sibling of idx
        while idx > 1:
            self.tree[idx // 2] = self.feature_func(self.tree[idx], self.tree[idx ^ 1])
            idx = idx // 2

    def query(self, lb, rb):
        # feature of arr[lb..rb], both ends inclusive
        lb += self.length
        rb += self.length
        nodes = []
        while lb < rb:
            # lb is a right child: its parent also covers lb - 1, so take lb alone
            if lb % 2 == 1:
                nodes.append(self.tree[lb])
                lb += 1
            # rb is a left child: its parent also covers rb + 1, so take rb alone
            if rb % 2 == 0:
                nodes.append(self.tree[rb])
                rb -= 1
            lb = lb // 2
            rb = rb // 2
        if lb == rb:
            nodes.append(self.tree[rb])
        return self.feature_func(*nodes)

Similarly, some typical segment tree feature functions are given as follows:

  1. Find the maximum value in the range

    def feature_func(self, *args):
        return max(args)
    
  2. Find the minimum value in the range

    def feature_func(self, *args):
        return min(args)
    
  3. Find the sum of elements in a range

    def feature_func(self, *args):
        return sum(args)
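To see why the bottom-up query is correct, the sketch below mirrors its loop but records which tree slots are collected instead of their values (collected_slots and leaves_under are hypothetical helpers). For n = 8 and the range [2, 6], it collects slot 14 (leaf 6), slot 5 (leaves 2 and 3) and slot 6 (leaves 4 and 5), which together cover the range exactly once.

```python
from typing import List


def collected_slots(n: int, lb: int, rb: int) -> List[int]:
    """Mirror the bottom-up query loop, returning tree slots instead of their values."""
    lb += n
    rb += n
    slots = []
    while lb < rb:
        if lb % 2 == 1:  # right child: its parent also covers lb - 1, so take lb alone
            slots.append(lb)
            lb += 1
        if rb % 2 == 0:  # left child: its parent also covers rb + 1, so take rb alone
            slots.append(rb)
            rb -= 1
        lb //= 2
        rb //= 2
    if lb == rb:
        slots.append(lb)
    return slots


def leaves_under(slot: int, n: int) -> List[int]:
    """Array indices stored in the leaves below `slot` (leaves are slots n..2n-1)."""
    stack, result = [slot], []
    while stack:
        s = stack.pop()
        if s >= n:
            result.append(s - n)
        else:
            stack.extend((2 * s, 2 * s + 1))
    return sorted(result)


if __name__ == "__main__":
    n = 8
    for s in collected_slots(n, 2, 6):
        print(s, leaves_under(s, n))
```

The same exact covering also holds when n is not a power of two. Note that the pieces are collected out of order (slot 14 before slots 5 and 6), so this layout relies on the feature being commutative, as sum, min and max are.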
    

Implementations found online more commonly use bit operations for the index arithmetic, as follows:

class SegmentTree:
    def __init__(self, arr):
        self.length = len(arr)
        self.tree = self.build(arr)

    def feature_func(self, *args):
        # combine any number of features, such as sum, min or max; override in a subclass
        raise NotImplementedError()

    def build(self, arr):
        n = len(arr)
        tree = [0 for _ in range(2*n)]
        for i in range(n):
            tree[i + n] = arr[i]
        # i<<1 is 2*i and (i<<1) | 1 is 2*i + 1
        for i in range(n-1, 0, -1):
            tree[i] = self.feature_func(tree[i<<1], tree[(i<<1) | 1])
        return tree

    def update(self, idx, val):
        idx = idx + self.length
        self.tree[idx] = val
        # idx>>1 is the parent of idx, idx^1 its sibling
        while idx > 1:
            self.tree[idx>>1] = self.feature_func(self.tree[idx], self.tree[idx^1])
            idx = idx>>1

    def query(self, lb, rb):
        lb += self.length
        rb += self.length
        nodes = []
        while lb < rb:
            if (lb & 1) == 1:
                nodes.append(self.tree[lb])
                lb += 1
            if (rb & 1) == 0:
                nodes.append(self.tree[rb])
                rb -= 1
            lb = lb >> 1
            rb = rb >> 1
        if lb == rb:
            nodes.append(self.tree[rb])
        return self.feature_func(*nodes)
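The bit tricks rely on a few identities for non-negative indices. In particular, idx ^ 1 in update is the sibling of idx, because flipping the lowest bit maps a left child 2k to the right child 2k + 1 and back. A quick check (parent, children and sibling are hypothetical helper names):

```python
from typing import Tuple


def parent(idx: int) -> int:
    return idx >> 1  # same as idx // 2 for non-negative idx


def children(idx: int) -> Tuple[int, int]:
    return idx << 1, (idx << 1) | 1  # same as (2*idx, 2*idx + 1)


def sibling(idx: int) -> int:
    return idx ^ 1  # flips the lowest bit: 2k <-> 2k + 1


if __name__ == "__main__":
    for i in range(1, 64):
        assert parent(i) == i // 2
        assert children(i) == (2 * i, 2 * i + 1)
        assert parent(sibling(i)) == parent(i) and sibling(i) != i
    print("identities hold for 1..63")
```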

3. Worked example

1. LeetCode 2659

Problem link:

  • 2659. Make Array Empty

1. Problem-solving ideas

The idea for this problem is fairly straightforward. Elements are always deleted in increasing order of value, so we only need to sort the elements by value and record their original indices. Each time an element is deleted, the number of rotations needed equals the number of elements still remaining between the previously deleted position and the current one (wrapping around the end of the array if necessary), not counting the element being deleted. Each deletion itself then costs one more operation.

Therefore, we only need a range-sum segment tree over a 0/1 array that marks which elements are still present.
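For cross-checking the segment tree solution on small inputs, a direct simulation of the problem's rules is handy. The sketch below (count_operations_brute_force is a hypothetical name) is much slower, but it is easy to trust: if the first element is the smallest it is removed, otherwise it is moved to the end.

```python
from collections import deque
from typing import List


def count_operations_brute_force(nums: List[int]) -> int:
    """Directly simulate LeetCode 2659; slow, intended only for small inputs."""
    queue = deque(nums)
    ops = 0
    while queue:
        if queue[0] == min(queue):
            queue.popleft()   # first element is the smallest: remove it
        else:
            queue.rotate(-1)  # otherwise move the first element to the end
        ops += 1
    return ops


if __name__ == "__main__":
    print(count_operations_brute_force([3, 4, -1]))
```

For [3, 4, -1], the simulation rotates twice and then removes -1, 3 and 4, for 5 operations in total.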

2. Code implementation

The Python implementation is as follows:

from typing import List

class SegmentTree:
    def __init__(self, arr):
        self.length = len(arr)
        self.tree = self.build(arr)

    def feature_func(self, *args):
        return sum(args)

    def build(self, arr):
        n = len(arr)
        tree = [0 for _ in range(2*n)]
        for i in range(n):
            tree[i + n] = arr[i]
        for i in range(n-1, 0, -1):
            tree[i] = self.feature_func(tree[i<<1], tree[(i<<1) | 1])
        return tree

    def update(self, idx, val):
        idx = idx + self.length
        self.tree[idx] = val
        while idx > 1:
            self.tree[idx>>1] = self.feature_func(self.tree[idx], self.tree[idx^1])
            idx = idx>>1

    def query(self, lb, rb):
        lb += self.length
        rb += self.length
        nodes = []
        while lb < rb:
            if (lb & 1) == 1:
                nodes.append(self.tree[lb])
                lb += 1
            if (rb & 1) == 0:
                nodes.append(self.tree[rb])
                rb -= 1
            lb = lb >> 1
            rb = rb >> 1
        if lb == rb:
            nodes.append(self.tree[rb])
        return self.feature_func(*nodes)

class Solution:
    def countOperationsToEmptyArray(self, nums: List[int]) -> int:
        n = len(nums)
        # original indices sorted by value: the order in which elements are removed
        index = [i for i in range(n)]
        index = sorted(index, key=lambda x: nums[x])

        # status[i] == 1 while nums[i] is still in the array
        status = [1 for _ in range(n)]
        segment_tree = SegmentTree(status)
        prev = 0
        res = 0
        for idx in index:
            if idx >= prev:
                # rotate past the remaining elements in [prev, idx)
                res += segment_tree.query(prev, idx) - 1
            else:
                # wrap around: remaining elements in [prev, n-1] and in [0, idx)
                res += segment_tree.query(0, idx) + segment_tree.query(prev, n-1) - 1
            prev = idx
            segment_tree.update(idx, 0)
        # plus one removal operation per element
        return res + n

Submitting this code gives a running time of 6159 ms and a memory usage of 31.2 MB.

4. References

  • https://www.hackerearth.com/practice/data-structures/advanced-data-structures/segment-trees/tutorial/
  • https://www.geeksforgeeks.org/segment-tree-sum-of-given-range/
  • https://www.geeksforgeeks.org/segment-tree-efficient-implementation/