Stratified Sampling in Python Programming


Hey fellow Python coder! In this tutorial, we will learn what stratified sampling is and how to implement it in Python. Let's start by learning what sampling is in the next section.

What is Sampling?

Let's pretend you have a small rack of books at a library. If you want to know what kinds of books are on the rack, reading every single book would be very time-consuming. Sampling, in this example, means taking a few books from the rack, reading a bit of each one, and using that knowledge to understand the different kinds of stories available.

By sampling the books on the rack, you get a sneak peek into the variety without reading every book. It is a more manageable way to explore what is there: sampling lets you make quicker judgments and get a sense of the complete collection without going through each book individually.
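As a minimal sketch of plain (simple random) sampling, the snippet below draws a few books from a hypothetical rack using Python's built-in random.sample. The titles, the sample size of 3 and the seed are made up for illustration.

```python
import random

# Hypothetical rack of books (titles invented for illustration)
rack = ["Book A", "Book B", "Book C", "Book D", "Book E",
        "Book F", "Book G", "Book H", "Book I", "Book J"]

random.seed(42)  # fixed seed so the draw is reproducible

# Simple random sample: 3 distinct books, every book equally likely to be picked
sample = random.sample(rack, k=3)
print(sample)
```

Every book has the same chance of being chosen, and no book is picked twice because random.sample draws without replacement.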

What is Stratified Sampling?

Now imagine that the books on the rack are not all of the same category but belong to several different categories. If a reader picks five books completely at random, it is possible that all five come from the same category, and we want to avoid that situation.
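To put a number on that risk, the sketch below assumes a hypothetical rack of 20 books, 5 in each of 4 genres, and computes the probability that a simple random sample of 5 books misses at least one genre entirely. It uses inclusion-exclusion with math.comb and cross-checks the result with a quick simulation; the genre names and counts are illustrative assumptions.

```python
import math
import random

# Hypothetical rack: 4 genres with 5 books each (20 books in total)
genres = ["Mystery", "Horror", "Romance", "Fantasy"]
books_per_genre = 5
sample_size = 5
total = len(genres) * books_per_genre

# Exact probability that at least one genre is absent from the sample,
# by inclusion-exclusion over the number j of excluded genres
p_missing = sum(
    (-1) ** (j + 1) * math.comb(len(genres), j)
    * math.comb(total - j * books_per_genre, sample_size)
    for j in range(1, len(genres) + 1)
) / math.comb(total, sample_size)

# Monte Carlo cross-check: draw many simple random samples of 5 books
rack = [g for g in genres for _ in range(books_per_genre)]
random.seed(0)
trials = 100_000
misses = sum(
    len(set(random.sample(rack, sample_size))) < len(genres)
    for _ in range(trials)
)

print(f"exact: {p_missing:.4f}, simulated: {misses / trials:.4f}")
```

With these assumed numbers the exact probability is 10504/15504, roughly 0.68, so about two out of three purely random picks leave out at least one genre.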


To make sure the reader explores a wide variety of books, we divide all the books on the rack into categories such as Mystery, Horror, Romance, and so on. In stratified sampling, these categories are called strata (singular: stratum). The reader then picks books at random from each stratum, so every category is represented in the selection.
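The same idea in plain Python: a minimal sketch, assuming a hypothetical rack already grouped by genre, where a hypothetical helper stratified_sample draws the same number of books at random from every stratum.

```python
import random

# Hypothetical rack grouped into strata by genre (titles invented for illustration)
rack_by_genre = {
    "Mystery": ["M1", "M2", "M3", "M4"],
    "Horror": ["H1", "H2", "H3"],
    "Romance": ["R1", "R2", "R3", "R4", "R5"],
}


def stratified_sample(strata, per_stratum, seed=None):
    """Draw `per_stratum` items at random (without replacement) from every stratum."""
    rng = random.Random(seed)  # separate generator so the result is reproducible
    picks = {}
    for name, items in strata.items():
        picks[name] = rng.sample(items, per_stratum)
    return picks


picks = stratified_sample(rack_by_genre, per_stratum=2, seed=1)
print(picks)
```

Unlike a single random draw from the whole rack, every genre is guaranteed to appear in the result, as long as each stratum holds at least per_stratum books.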

Python Code Implementation for Stratified Sampling

In this section, we will implement stratified sampling on a real-world dataset. For this tutorial, we will use the Iris dataset that ships with the scikit-learn (sklearn) library.

Step 1 – Importing Modules

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

Step 2 – Load Iris Dataset

We load the dataset with the load_iris() function and store the features together with the target labels in a single DataFrame. We also count the number of distinct flower classes (species) in the dataset using np.unique.

iris = load_iris()
data = pd.DataFrame(data= np.c_[iris['data'], iris['target']],
                    columns= iris['feature_names'] + ['target'])
num_classes = len(np.unique(iris.target))
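Before sampling, it helps to check how large each stratum is. The sketch below rebuilds the Step 2 DataFrame so it runs on its own, then counts the rows per class and maps the numeric labels back to species names. Note that np.c_ stores the labels as floats, hence the int() conversion.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris

# Rebuild the same DataFrame as in Step 2 so this snippet is self-contained
iris = load_iris()
data = pd.DataFrame(data=np.c_[iris['data'], iris['target']],
                    columns=iris['feature_names'] + ['target'])

print(data.shape)  # (rows, columns)

# Rows per class; labels are floats (0.0, 1.0, 2.0) because np.c_ upcasts them
counts = data['target'].value_counts().sort_index()
counts.index = [iris.target_names[int(label)] for label in counts.index]
print(counts)
```

The Iris dataset has 150 flowers, 50 of each of the three species, so every stratum is large enough to draw 5 samples from.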

Step 3 – Visualize the Initial Dataset

We visualize the data with the matplotlib library by plotting sepal length against sepal width, with one color per species:

plt.figure(figsize=(10, 6))
for class_label in range(num_classes):
    class_data = data[data['target'] == class_label]
    plt.scatter(class_data['sepal length (cm)'], class_data['sepal width (cm)'],
                label=iris.target_names[class_label])

plt.title('Original Dataset Visualization')
plt.xlabel('Sepal Length in cm')
plt.ylabel('Sepal Width in cm')
plt.legend()
plt.show()

The resulting plot is as follows:

[Figure: Original Dataset Visualization, sepal length vs. sepal width (cm) for each iris species]

Step 4 – Apply Stratified Sampling on the Dataset

To apply stratified sampling to the data, we use the function below, which returns num_samples_per_class rows from every class. It iterates over the unique target labels (classes), randomly samples num_samples_per_class instances from each class, and stores them in the list samples. Finally, it concatenates the per-class samples into a single DataFrame and returns it.

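def applyStratifiedSampling(data, num_samples_per_class, random_state=None):
    samples = []
    # Loop over the class labels that actually occur in the 'target' column
    for class_label in data['target'].unique():
        class_data = data[data['target'] == class_label]
        # Randomly pick num_samples_per_class rows (without replacement) from this class;
        # pass an integer random_state to make the draw reproducible
        samples.append(class_data.sample(n=num_samples_per_class, random_state=random_state))
    # Combine the per-class samples into a single DataFrame
    return pd.concat(samples)

stratifiedSampleData = applyStratifiedSampling(data, 5)
print("Stratified Dataset is:\n", stratifiedSampleData)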

The resulting sampled data is shown below. It contains 15 data points in total: 5 randomly selected points from each of the three target classes.

[Output: the printed stratifiedSampleData DataFrame, 15 rows]
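pandas can perform the same per-class draw in a single call with groupby(...).sample(...), which is available in pandas 1.1 and later. A minimal sketch, assuming the data DataFrame from Step 2 (rebuilt here so the snippet runs on its own); random_state=0 is an arbitrary seed chosen for reproducibility.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris

# Rebuild the Step 2 DataFrame so this snippet is self-contained
iris = load_iris()
data = pd.DataFrame(data=np.c_[iris['data'], iris['target']],
                    columns=iris['feature_names'] + ['target'])

# Group rows by class, then draw 5 random rows (without replacement) from each group
grouped_sample = data.groupby('target').sample(n=5, random_state=0)

# Confirm that every class contributes exactly 5 rows
print(grouped_sample['target'].value_counts().sort_index())
```

This produces the same kind of result as applyStratifiedSampling(data, 5): 15 rows, 5 per class.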

Step 5 – Visualization of Stratified Sampled Dataset

We visualize the stratified sample in the same way with matplotlib:

plt.figure(figsize=(10, 6))
for class_label in range(num_classes):
    class_data = stratifiedSampleData[stratifiedSampleData['target'] == class_label]
    plt.scatter(class_data['sepal length (cm)'], class_data['sepal width (cm)'],
                label=iris.target_names[class_label])

plt.title('Stratified Sampling Dataset')
plt.xlabel('Sepal Length in cm')
plt.ylabel('Sepal Width in cm')
plt.legend()
plt.show()

The scatter plot of the sampled data points looks as follows:

[Figure: Stratified Sampling Dataset, sepal length vs. sepal width (cm) of the sampled points for each iris species]
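The function above takes the same number of rows from every class (equal allocation), which matches the class proportions here only because Iris is perfectly balanced. When strata differ in size, a common alternative is proportional allocation: take the same fraction from each stratum. The sketch below assumes a hypothetical imbalanced population of 100 items (60/30/10) and shows it two ways, with pandas groupby(...).sample(frac=...) and with the stratify parameter of scikit-learn's train_test_split.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced population: 60 / 30 / 10 items in strata A, B, C
population = pd.DataFrame({
    "item": range(100),
    "stratum": ["A"] * 60 + ["B"] * 30 + ["C"] * 10,
})

# Proportional allocation with pandas: keep 20% of every stratum
prop_sample = population.groupby("stratum").sample(frac=0.2, random_state=0)

# Same idea with scikit-learn: stratify= preserves the stratum proportions
sk_sample, _ = train_test_split(population, train_size=20,
                                stratify=population["stratum"], random_state=0)

print(prop_sample["stratum"].value_counts().sort_index())
print(sk_sample["stratum"].value_counts().sort_index())
```

Both draws keep the 60/30/10 ratio of the population, yielding 12, 6 and 2 items from strata A, B and C.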

I hope you now understand what stratified sampling is and how to implement it in Python.


 Happy Learning!
