[TopK problem] Heap-based method & divide-and-conquer strategy-based method

Description:

  1. TopK problem: given an array, select the k largest/smallest elements, or select the kth largest/kth smallest element;
  2. This article summarizes two implementation methods:
    • Heap-based method: unlike heap sort, it only maintains a heap of k elements, which at the end holds the k largest/smallest elements;
    • Divide-and-conquer method: it partitions the original array as quick sort does, but unlike quick sort it processes only one side of each partition;
  3. This article was compiled as personal study notes. If you find any errors, corrections are welcome.

Article directory

  • 1. Heap-based method
    • 1.1 Algorithm steps
    • 1.2 Algorithm implementation
  • 2. Method based on divide and conquer strategy
    • 2.1 Algorithm steps
    • 2.2 Algorithm implementation

1. Heap-based method

1.1 Algorithm steps

To get the k largest elements, build a min-heap (small root heap) containing k elements; correspondingly, to get the k smallest elements, build a max-heap (large root heap). The heap top is the smallest of the k elements kept so far, so it is the only element a new candidate needs to be compared with.

  1. First, build a min-heap from the first k elements of the original array;
  2. Traverse the array from the (k + 1)th element onward, comparing each element with the heap top. If the element is larger than the heap top, replace the heap top with it and re-adjust (sift down) the heap; otherwise, move on to the next element;
  3. When the traversal is complete, the k elements stored in the min-heap are the k largest elements of the original array, and the heap top is the kth largest.
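
The three steps above can also be sketched with the standard library's std::priority_queue in place of a hand-written heap. This is only a minimal sketch under stated assumptions, not the implementation used in this article: the helper name topKLargest is hypothetical, and std::greater<int> is what turns the default max-heap into a min-heap.

```cpp
#include <cstddef>
#include <functional>
#include <queue>
#include <vector>

// Hypothetical helper: returns the k largest values of nums in ascending order.
std::vector<int> topKLargest(const std::vector<int>& nums, std::size_t k) {
    // std::greater makes priority_queue a min-heap: top() is the smallest kept value.
    std::priority_queue<int, std::vector<int>, std::greater<int>> heap;
    for (int x : nums) {
        if (heap.size() < k) {
            heap.push(x);                      // step 1: fill the heap with the first k elements
        } else if (k > 0 && x > heap.top()) {
            heap.pop();                        // step 2: a larger element replaces the heap top
            heap.push(x);
        }
    }
    std::vector<int> result;                   // step 3: the heap holds the k largest elements
    while (!heap.empty()) {
        result.push_back(heap.top());          // a min-heap pops its smallest element first
        heap.pop();
    }
    return result;
}
```

For example, topKLargest({3, 2, 1, 5, 6, 4}, 2) yields {5, 6}, and the first element of the result is the kth largest value. Removing the std::greater comparator and flipping the comparison to x < heap.top() would keep the k smallest elements instead.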

1.2 Algorithm implementation

Related LeetCode problem: 215. Kth Largest Element in an Array

    //Build a min-heap of k elements from the first k elements of the array.
    //Then traverse from the (k + 1)th element, comparing each element with the heap top. If it is larger than the heap top, replace the heap top and re-adjust the heap, so that the heap always holds the k largest elements seen so far.
    int findKthLargest(vector<int>& nums, int k) {
        vector<int> heap_k(nums.begin(), nums.begin() + k); //Copy the first k elements of nums
        BuildMinHeap(heap_k); //Build these k elements into a min-heap

        for(int i=k; i<(int)nums.size(); i++){ //Start from the (k + 1)th element (subscript k) and compare each element with the heap top
            if(nums[i] > heap_k[0]){
                heap_k[0] = nums[i];
                MinHeapAdjust(heap_k, 0, k); //The element is larger than the heap top: replace the heap top and re-adjust the heap
            }
        }
        return heap_k[0]; //heap_k is a min-heap of the k largest elements, so heap_k[0] is the kth largest element of the original array
    }
    //Build a min-heap
    void BuildMinHeap(vector<int>& nums){
        int n = nums.size();
        for(int i=n/2-1; i>=0; i--){ //Sift down every non-leaf node, starting from the last one (subscript n/2 - 1)
            MinHeapAdjust(nums, i, n);
        }
    }
    //Sift the node at subscript i down within the first n elements of the min-heap
    void MinHeapAdjust(vector<int>& nums, int i, int n){
        int temp = nums[i]; //Temporarily store the node being sifted down
        for(int j=i*2+1; j<n; j=i*2+1){ //j initially points to the left child of node i
            if(j+1<n && nums[j+1]<nums[j]) j++; //Make j point to the smaller of the two children of i

            if(temp <= nums[j]) break; //temp is not larger than either child, so the min-heap property already holds below this position and sifting can stop early
            else{
                nums[i] = nums[j]; //Otherwise, move the smaller child up to the parent position
                i = j; //Update i to continue sifting downward
            }
        }
        nums[i] = temp; //Place the sifted node at its final position
    }

2. Method based on divide and conquer strategy

2.1 Algorithm steps

In quick sort, the most important step is pivotPos = Partition(nums, left, right). It selects an element of the array as the pivot and divides the elements with subscripts from left to right into two parts: elements smaller than the pivot are placed on its left, and elements larger than the pivot on its right. By repeatedly partitioning the array, the whole array is eventually sorted.

The TopK problem uses the same divide-and-conquer partitioning, but where quick sort must process both the left and right parts after each partition, the TopK problem only needs to process one side. Because the goal is the k largest/smallest elements, comparing pivotPos with k determines whether the left or the right side is processed next; the other side can be ignored.
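
For comparison, the C++ standard library exposes this kind of one-sided selection as std::nth_element. The sketch below is a swapped-in library alternative, not the hand-written method of this article; the helper name kthLargest is hypothetical, and the standard only promises linear time on average without saying which algorithm is used internally.

```cpp
#include <algorithm>
#include <functional>
#include <vector>

// Hypothetical helper: k-th largest value of nums, assuming 1 <= k <= nums.size().
// The vector is taken by value because std::nth_element reorders its input.
int kthLargest(std::vector<int> nums, int k) {
    // With std::greater, subscript k-1 receives the element that a full
    // descending sort would put there; elements before it are >= it,
    // elements after it are <= it, and neither side is sorted further.
    std::nth_element(nums.begin(), nums.begin() + (k - 1), nums.end(),
                     std::greater<int>());
    return nums[k - 1];
}
```

For example, kthLargest({3, 2, 1, 5, 6, 4}, 2) returns 5.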

2.2 Algorithm implementation

P.S. The RANDOMIZED-SELECT algorithm corresponds to Section 9.2, “Selection in expected linear time”, of “Introduction to Algorithms (3rd Edition)”, and the SELECT algorithm corresponds to Section 9.3, “Selection in worst-case linear time”.

  • Algorithm implementation based on random selection
    //Divide and conquer with a randomly selected pivot
    int findKthLargest(vector<int>& nums, int k) {
        randomizedSelect(nums, 0, nums.size()-1, k-1); //In descending order, the kth largest element has subscript k-1
        return nums[k-1];
    }

    //Partition around a randomly selected pivot (as in RANDOMIZED-SELECT)
    int Partition(vector<int>& nums, int left, int right){
        int pivotPos = rand()%(right-left+1) + left; //Random subscript in the range [left,right]
        int pivot = nums[pivotPos]; //The randomly selected element is the pivot

        swap(nums[left], nums[pivotPos]); //Swap the pivot with the leftmost element, so that the leftmost position acts as the pivot slot (the textbook calls the pivot the pivot element)

        //Partition in descending order: elements >= pivot go left, elements <= pivot go right
        while(left<right){
            while(left<right && nums[right]<=pivot) right--; //The pivot was taken from the leftmost slot, so the right pointer must move first
            nums[left] = nums[right];
            while(left<right && nums[left]>=pivot) left++;
            nums[right] = nums[left];
        }
        pivotPos = left; //Where the two pointers meet is the final position of the pivot
        nums[pivotPos] = pivot;
        return pivotPos;
    }
    //Recursive random selection; k is the target subscript (0-based)
    void randomizedSelect(vector<int>& nums, int left, int right, int k){
        if(left >= right) return; //Recursion ends when the range has at most one element

        int pivotPos = Partition(nums, left, right);
        if(pivotPos == k) return; //The pivot landed on the target subscript
        else if(pivotPos > k) randomizedSelect(nums, left, pivotPos-1, k); //Descending order: pivotPos > k means the target is on the left side
        else randomizedSelect(nums, pivotPos+1, right, k);
    }
  • Worst-case linear-time selection algorithm SELECT. The following steps are from “Introduction to Algorithms”, where i denotes the rank of the element being sought (the ith smallest):
  1. Divide the n elements of the input array into ⌊n/5⌋ groups of 5 elements each, plus at most one group made up of the remaining n mod 5 elements;
  2. Find the median of each of the ⌈n/5⌉ groups: first insertion-sort the elements of each group, then pick the median from the sorted group;
  3. Recursively call SELECT on the ⌈n/5⌉ medians found in step 2 to find their median x (if there is an even number of medians, x is the lower median);
  4. Partition the input array around the median of medians x. Let k be one more than the number of elements on the low side of the partition, so that x is the kth smallest element and there are n-k elements on the high side;
  5. If i=k, return x. Otherwise, recursively call SELECT to find the ith smallest element on the low side if i<k, or the (i-k)th smallest element on the high side if i>k.
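
As a sketch of why choosing the median of medians as the pivot gives linear worst-case time, the bound below follows the analysis in Section 9.3 of “Introduction to Algorithms (3rd Edition)”, assuming distinct elements. At least half of the group medians are greater than or equal to x, and each of their groups contributes 3 elements greater than x, except the group containing x and the possibly incomplete last group. The same count holds symmetrically for elements smaller than x, so the recursive call in step 5 receives at most 7n/10 + 6 elements:

```latex
\begin{aligned}
\#\{\text{elements} > x\}
  &\ge 3\left(\left\lceil \tfrac{1}{2}\left\lceil \tfrac{n}{5} \right\rceil \right\rceil - 2\right)
   \ge \frac{3n}{10} - 6, \\
T(n) &\le T\!\left(\left\lceil \tfrac{n}{5} \right\rceil\right)
        + T\!\left(\tfrac{7n}{10} + 6\right) + O(n).
\end{aligned}
```

Because 1/5 + 7/10 = 9/10 < 1, the two recursive subproblems together shrink geometrically, and the recurrence solves to T(n) = O(n).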

References:
linear time selection problem
BFPRT–The ultimate solution to Top k problems

Select the kth largest element:

//Find the kth largest element (SELECT with the median of medians as pivot, descending order)
#include <algorithm>
#include <functional>
#include <vector>
using namespace std;

//Partition function
int Partition(vector<int>& nums, int left, int right, int pivot){ //pivot is the median of medians passed in
    for(int index=left; index<=right; index++){ //Find the subscript of pivot within [left, right]
        if(nums[index] == pivot){
            swap(nums[left], nums[index]); //Move the pivot to the leftmost position
            break;
        }
    }
    //Partition in descending order
    int i=left, j=right;
    while(i<j){
        while(i<j && nums[j]<=pivot) j--;
        while(i<j && nums[i]>=pivot) i++;
        swap(nums[i], nums[j]);
    }
    swap(nums[left], nums[i]); //Put the pivot at its final position
    return i;
}

//Sort the data in the range [begin, end] in descending order and return the subscript of the median
int indexOfMedian(vector<int>& nums, int begin, int end){
    sort(nums.begin() + begin, nums.begin() + end + 1, greater<int>());
    int index = begin + (end - begin)/2;
    return index;
}

//Return the kth largest element (kth is 1-based) within [left, right]
//For the whole array, call select(nums, 0, nums.size()-1, k)
int select(vector<int>& nums, int left, int right, int kth){
    if(right-left+1 <= 5){
        //With at most 5 elements, sort directly and return the kth one
        sort(nums.begin() + left, nums.begin() + right + 1, greater<int>());
        return nums[left + kth - 1]; //kth is 1-based, so subtract one for the subscript
    }

    int count = right - left + 1;
    int groups = count/5 + (count%5 > 0 ? 1 : 0); //Total number of groups
    for(int i=0; i<groups; i++){ //i is the group number, starting from 0
        int index = indexOfMedian(nums, left + i*5, min(left + i*5 + 4, right)); //The last group may contain fewer than 5 elements
        swap(nums[left + i], nums[index]); //Move the median to the front of the range so the medians are contiguous
    }

    int pivot = select(nums, left, left + groups - 1, groups/2 + 1); //Median of medians; kth is 1-based, so with an even count this picks the smaller median
    int pivotPos = Partition(nums, left, right, pivot); //Partition around pivot

    int num_left = pivotPos - left + 1; //Number of elements in [left, pivotPos]
    if(num_left == kth) return nums[pivotPos]; //The element at pivotPos is exactly the kth element
    else if(num_left > kth) return select(nums, left, pivotPos-1, kth); //kth is in the left part
    else return select(nums, pivotPos+1, right, kth-num_left); //kth is in the right part; kth-num_left is its rank within the right part
}

P.S. My understanding of the SELECT algorithm is not yet thorough enough. I will expand this part later.