Implementing Quick Select in Python

If I ask you to think of an algorithm that finds the kth smallest element in a list of integers, your answer would probably be this: sort the list first, then take the element at index k-1. This is a simple and effective solution, and its time complexity depends on the sorting algorithm used: O(n log n) for an efficient comparison sort such as merge sort or Python's built-in sorted().
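For reference, the sort-based approach takes only a few lines of Python. This is a minimal sketch: the function name `kth_smallest_by_sorting` is our own, and k is assumed to be 1-based (k = 1 means the smallest element).

```python
def kth_smallest_by_sorting(numbers, k):
    """Return the kth smallest element (1-based k) by sorting a copy of the list."""
    if not 1 <= k <= len(numbers):
        raise ValueError("k must be between 1 and len(numbers)")
    # sorted() returns a new list, so the input is left unchanged; O(n log n).
    return sorted(numbers)[k - 1]


print(kth_smallest_by_sorting([7, 2, 9, 4, 1], 2))  # prints 2
```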

But how can we find the kth smallest element in a list without (completely) sorting it? The answer is Quick Select, which is built on a simple technique that is also at the heart of quicksort: partitioning.


How Does Quick Select Work?

In order to understand how this algorithm works, let’s first look at the pseudocode:

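  1. Choose a pivot p.
  2. Partition the array around p (the same partitioning step as in quicksort) into two sub-arrays:
       LEFT -> elements smaller than or equal to p (excluding p itself)
       RIGHT -> elements greater than p
  3. Let index(p) be the position of p after partitioning, counted from 1.
  4. If index(p) == k:
         Return p (or the index of p)
  5. If k < index(p):
         QuickSelect(LEFT, k)
  6. Else:
         QuickSelect(RIGHT, k - index(p))

The logic follows directly from the pseudocode: if the position of the pivot after partitioning equals k, the pivot is the answer; if k is smaller than the pivot's position, the kth smallest element must be in LEFT, so we recurse on the left side; otherwise we recurse on the right side. When recursing on RIGHT, k is reduced by index(p), because the index(p) - 1 elements of LEFT and the pivot itself all come before every element of RIGHT. Unlike quicksort, which recurses on both sides of the partition, Quick Select recurses on only one.

NOTE: The same algorithm finds the kth largest element if the partition is reversed (LEFT holds the elements greater than or equal to p and RIGHT the elements smaller than p), or, equivalently, if we look for the (n - k + 1)th smallest element, where n is the length of the list.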
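The pseudocode translates almost line for line into Python. The sketch below builds LEFT and RIGHT as new lists instead of partitioning in place (easier to read, but it uses O(n) extra memory per call) and picks the pivot at random. Both choices, and the names `quick_select` and `kth_largest`, are our own assumptions rather than part of the pseudocode. The kth-largest helper uses the (n - k + 1)th-smallest equivalence from the note.

```python
import random


def quick_select(numbers, k):
    """Return the kth smallest element of numbers (k is 1-based)."""
    if not 1 <= k <= len(numbers):
        raise ValueError("k must be between 1 and len(numbers)")
    # Choose a pivot p (assumption: a uniformly random position).
    pivot_i = random.randrange(len(numbers))
    p = numbers[pivot_i]
    rest = numbers[:pivot_i] + numbers[pivot_i + 1:]
    # Partition the remaining elements around p.
    left = [x for x in rest if x <= p]   # LEFT: elements <= p
    right = [x for x in rest if x > p]   # RIGHT: elements > p
    # p's 1-based position once LEFT, p, RIGHT are laid out in order.
    index_p = len(left) + 1
    if k == index_p:
        return p
    if k < index_p:
        return quick_select(left, k)
    # Skip LEFT and p itself when searching RIGHT.
    return quick_select(right, k - index_p)


def kth_largest(numbers, k):
    """The kth largest element is the (n - k + 1)th smallest."""
    return quick_select(numbers, len(numbers) - k + 1)


data = [7, 10, 4, 3, 20, 15]
print(quick_select(data, 3))  # prints 7
print(kth_largest(data, 2))   # prints 15
```

Because slicing and list comprehensions create new lists, this version never reorders the caller's list; the in-place implementation later in the article avoids the extra memory instead.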

Complexity analysis

Considering an input list of size n:

  • Best-case time complexity: O(n), when the first chosen pivot is also the kth smallest element, so a single partitioning pass over the n elements suffices.
  • Worst-case time complexity: O(n^2)

The worst case occurs when we are extremely unlucky in our pivot choices and each partition removes only one element (the pivot) from the list. The list then shrinks by just 1 in each recursive call.

This results in the following time complexity: O(n + (n - 1) + (n - 2) + ... + 2 + 1) = O(n(n + 1)/2), which is O(n^2).
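One way to see the quadratic worst case is to count comparisons. The sketch below uses the same last-element (Lomuto) partition scheme as the implementation further down, adds a comparison counter, and asks for the smallest element (k = 1) of an already sorted list. Every pivot is then the largest remaining element, so each pass removes exactly one element and the count is (n - 1) + (n - 2) + ... + 1 = n(n - 1)/2. The function name `count_comparisons_worst_case` is hypothetical, and the loop replaces recursion only to avoid Python's recursion limit.

```python
def count_comparisons_worst_case(n):
    """Count element comparisons when selecting the smallest element
    (k = 1) of the sorted list [0, 1, ..., n - 1] with a last-element pivot.
    Assumes n >= 1."""
    arr = list(range(n))
    comparisons = 0
    start, end = 0, n - 1
    while True:
        # Lomuto partition of arr[start..end] around arr[end].
        pivot = arr[end]
        i = start
        for j in range(start, end):
            comparisons += 1
            if arr[j] <= pivot:
                arr[i], arr[j] = arr[j], arr[i]
                i += 1
        arr[i], arr[end] = arr[end], arr[i]
        if i == start:   # pivot landed at position k = 1 of this range: done
            return comparisons
        end = i - 1      # otherwise keep searching the left part


for n in (10, 100, 1000):
    print(n, count_comparisons_worst_case(n), n * (n - 1) // 2)
```

Each printed line shows n, the measured count, and n(n - 1)/2; the last two columns agree (45, 4950 and 499500).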

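NOTE: Although the worst-case complexity is O(n^2), the average-case time complexity is O(n), which is why Quick Select is fast in practice. Intuitively, a typical pivot discards a constant fraction of the remaining elements: if, for example, every pivot split the list in half, the total work would be n + n/2 + n/4 + ... <= 2n.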
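The average-case bound assumes pivots that are not consistently bad. A common safeguard, not used in the implementation below (which always takes the last element as the pivot), is to choose the pivot at random. That makes the expected running time O(n) on any input of distinct values, including already sorted lists; inputs with many duplicate values can still degrade with this simple partition scheme. Here is a minimal in-place, iterative sketch, assuming 1-based k and a hypothetical function name `randomized_quickselect`:

```python
import random


def randomized_quickselect(numbers, k):
    """Return the kth smallest element (1-based k) using random pivots.
    Works on a copy, so the caller's list is not reordered."""
    if not 1 <= k <= len(numbers):
        raise ValueError("k must be between 1 and len(numbers)")
    arr = list(numbers)
    target = k - 1                 # 0-based index the answer will occupy
    lo, hi = 0, len(arr) - 1       # invariant: lo <= target <= hi
    while True:
        # Move a random element of arr[lo..hi] to the end; it becomes the pivot.
        r = random.randint(lo, hi)
        arr[r], arr[hi] = arr[hi], arr[r]
        pivot = arr[hi]
        # Lomuto partition of arr[lo..hi].
        i = lo
        for j in range(lo, hi):
            if arr[j] <= pivot:
                arr[i], arr[j] = arr[j], arr[i]
                i += 1
        arr[i], arr[hi] = arr[hi], arr[i]
        # arr[i] is now in its final sorted position.
        if i == target:
            return arr[i]
        if target < i:
            hi = i - 1             # answer lies in the left part
        else:
            lo = i + 1             # answer lies in the right part


print(randomized_quickselect([9, 1, 8, 2, 7, 3], 4))  # prints 7
```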

Implementation of Quick Select in Python

We have covered enough theory. Now it's time to implement Quick Select in Python. Here is the code:

def quickselect(list_of_numbers, k):
    """
    Input: a list of numbers and an integer 'k' (1-based).
    Output: kth smallest element in the list, or None if it does not exist.
    Note: the list is partially reordered in place.
    Complexity: best case: O(n)
                worst case: O(n^2)
    """
    quick_selected = _kthSmallest(list_of_numbers, k, 0, len(list_of_numbers) - 1)
    if quick_selected is not None:
        print(f'The {k}th smallest element of the given list is {quick_selected}')
    else:
        print('k-th element does not exist')
    return quick_selected

def _kthSmallest(arr, k, start, end):
    """Private helper for quickselect: kth smallest element of arr[start..end]."""
    # check that k is between 1 and the
    # number of elements in arr[start..end]
    if 0 < k <= end - start + 1:
        # Partition the array with the last
        # element as the pivot and get the
        # position of the pivot element in
        # the sorted array
        pivot_index = _partition(arr, start, end)
        # if the position of the pivot
        # after partitioning is k
        if pivot_index - start == k - 1:
            return arr[pivot_index]
        # if the position of the pivot
        # is greater than k, recurse
        # on the left partition
        if pivot_index - start > k - 1:
            return _kthSmallest(arr, k, start, pivot_index - 1)
        # otherwise recurse on the right partition, skipping the
        # (pivot_index - start + 1) elements up to and including the pivot
        return _kthSmallest(arr, k - pivot_index + start - 1, pivot_index + 1, end)
    return None

def _partition(arr, l, r):
    """Private helper function (Lomuto partition scheme).
    Input: a list and two integers:
    l: start index of the list to be partitioned
    r: end index of the list to be partitioned

    Output: index of the pivot after partition (using arr[r] as the pivot)
    """
    pivot = arr[r]
    i = l
    # move every element <= pivot to the front of arr[l..r]
    for j in range(l, r):
        if arr[j] <= pivot:
            arr[i], arr[j] = arr[j], arr[i]
            i += 1
    # place the pivot right after the elements <= pivot
    arr[i], arr[r] = arr[r], arr[i]
    return i


Now let’s run an example:



The 4th smallest element of the given list is 3

Thank you for reading this article.
