Stratified Sampling in Python Programming

Hey fellow Python coder! In this tutorial, we will learn what stratified sampling is and how to implement it in Python. Let's start by looking at what sampling is in the next section.
What is Sampling?
Let's pretend you have a small rack of books at a library. If you want to know what kinds of books are on the rack, reading every single one would take a very long time. Sampling, in this example, means taking a few books from the rack, reading a bit of each, and using that to understand the different kinds of stories available.
By sampling the books on the rack, you get a sneak peek at the variety without reading every book. It is a more manageable way to explore what is there: sampling lets you make quicker judgments and get a sense of the whole collection without going through each book individually.
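In code, the simplest version of this idea is simple random sampling, where every book has the same chance of being picked. A minimal sketch using Python's standard random module, assuming a hypothetical rack of 20 books labelled "Book 1" to "Book 20":

```python
import random

# Hypothetical rack of 20 books (titles made up for illustration)
rack = [f"Book {i}" for i in range(1, 21)]

# Simple random sampling: pick 5 distinct books, each book equally likely
rng = random.Random(42)  # seeded only so the sketch is reproducible
sample = rng.sample(rack, k=5)
print(sample)
```

random.sample draws without replacement, so the same book never appears twice in one sample.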
What is Stratified Sampling?
Now imagine the books on the rack are not all of the same kind but belong to several categories. If a reader picks five books at random, it is possible that all five belong to the same category, and that is exactly the situation we want to avoid.
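To put a number on that risk, here is a brute-force count over every possible 5-book selection. It assumes a hypothetical rack of 4 categories with 5 books each (20 books in total); the category names are made up for illustration.

```python
from itertools import combinations

# Hypothetical rack: 4 categories, 5 books each (20 books in total)
categories = ["Mystery", "Horror", "Romance", "Sci-Fi"]
rack = [(cat, i) for cat in categories for i in range(5)]

total = 0
all_same_category = 0
missing_a_category = 0
# Enumerate every possible selection of 5 books
for pick in combinations(rack, 5):
    total += 1
    cats_in_pick = {cat for cat, _ in pick}
    if len(cats_in_pick) == 1:
        all_same_category += 1
    if len(cats_in_pick) < len(categories):
        missing_a_category += 1

print(total)                                  # 15504
print(all_same_category)                      # 4
print(missing_a_category)                     # 10504
print(round(missing_a_category / total, 3))   # 0.678
```

Under these assumptions, all five books coming from one category is rare (4 of 15504 selections), but a purely random pick leaves out at least one category in about 68% of selections.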
To make sure the reader explores a wide variety of books, we divide the books on the rack into categories such as Mystery, Horror, and Romance. In stratified sampling, these categories are called strata (singular: stratum). The reader then picks books randomly from each stratum, so every category is represented among the books to read and enjoy.
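A minimal sketch of the same idea in plain Python, assuming the rack is already grouped into a dictionary that maps each category (stratum) to its book titles. The function name stratified_sample and the titles are hypothetical, chosen for illustration only.

```python
import random

def stratified_sample(items_by_stratum, n_per_stratum, seed=None):
    """Draw n_per_stratum items at random, without replacement, from every stratum."""
    rng = random.Random(seed)
    sample = []
    for stratum, items in items_by_stratum.items():
        # random.sample raises ValueError if a stratum has fewer than n_per_stratum items
        picked = rng.sample(items, k=n_per_stratum)
        sample.extend((stratum, item) for item in picked)
    return sample

# Hypothetical rack grouped into strata (category -> titles)
rack = {
    "Mystery": ["M1", "M2", "M3", "M4"],
    "Horror": ["H1", "H2", "H3"],
    "Romance": ["R1", "R2", "R3", "R4", "R5"],
}

picks = stratified_sample(rack, n_per_stratum=2, seed=7)
print(picks)
```

Because the random draw happens separately inside each stratum, every category always contributes exactly n_per_stratum books, no matter how the random numbers fall.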
Python Code Implementation for Stratified Sampling
In this section, we will implement stratified sampling on a real-world dataset. For this tutorial, we will use the Iris dataset that ships with the scikit-learn (sklearn) library.
Step 1 – Importing Modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
Step 2 – Load Iris Dataset
We load the dataset with the load_iris() function, store the features together with the target labels in a DataFrame, and record the number of distinct flower classes in the dataset using NumPy's unique() function.
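iris = load_iris()

# Combine the four feature columns and the target label into one DataFrame
data = pd.DataFrame(data=np.c_[iris['data'], iris['target']],
                    columns=iris['feature_names'] + ['target'])

# Number of distinct flower classes (3 for Iris)
num_classes = len(np.unique(iris.target))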
Step 3 – Visualize the Initial Dataset
We visualize the data with the matplotlib library using the code snippet below:
plt.figure(figsize=(10, 6))
for class_label in range(num_classes):
    class_data = data[data['target'] == class_label]
    plt.scatter(class_data['sepal length (cm)'], class_data['sepal width (cm)'],
                label=iris.target_names[class_label])
plt.title('Original Dataset Visualization')
plt.xlabel('Sepal Length in cm')
plt.ylabel('Sepal Width in cm')
plt.legend()
plt.show()
The resulting plot is as follows:
Step 4 – Apply Stratified Sampling on the Dataset
The function below applies stratified sampling to the given data and returns num_samples_per_class rows from each class. It iterates over the unique target labels (classes), randomly samples num_samples_per_class instances from each class, and stores each class's sample in the list samples. Finally, it concatenates the samples from all classes and returns the resulting DataFrame.
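def applyStratifiedSampling(data, num_samples_per_class):
    samples = []
    # Iterate over the unique class labels found in the data itself
    for class_label in np.unique(data['target']):
        # Randomly pick num_samples_per_class rows belonging to this class
        class_samples = data[data['target'] == class_label].sample(num_samples_per_class)
        samples.append(class_samples)
    # Stack the per-class samples into a single DataFrame
    return pd.concat(samples)

stratifiedSampleData = applyStratifiedSampling(data, 5)
print("Stratified Dataset is : \n", stratifiedSampleData)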
The resulting sampled data is as follows. It contains 15 data points in total: 5 randomly chosen data points from each of the three target classes.
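If you prefer to let pandas handle the loop, DataFrameGroupBy.sample (available in pandas 1.1 and later) draws the same number of rows from each group. A minimal sketch on a small made-up DataFrame with an imbalanced target column, used here instead of Iris to keep the example self-contained; with the tutorial's data, the equivalent call would be data.groupby('target').sample(n=5).

```python
import pandas as pd

# Small hypothetical dataset: 6 rows of class 0, 4 of class 1, 2 of class 2
df = pd.DataFrame({
    "value": range(12),
    "target": [0] * 6 + [1] * 4 + [2] * 2,
})

# Draw 2 random rows from every class; random_state makes the draw reproducible
stratified = df.groupby("target").sample(n=2, random_state=0)
print(stratified)

# Each class should appear exactly twice
print(stratified["target"].value_counts().sort_index())
```

As with the loop version, sampling without replacement fails with a ValueError if any class has fewer rows than n.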
Step 5 – Visualization of Stratified Sampled Dataset
The stratified sample is visualized the same way with the matplotlib library, as shown in the code snippet and output below.
plt.figure(figsize=(10, 6))
for class_label in range(num_classes):
    class_data = stratifiedSampleData[stratifiedSampleData['target'] == class_label]
    plt.scatter(class_data['sepal length (cm)'], class_data['sepal width (cm)'],
                label=iris.target_names[class_label])
plt.title('Stratified Sampling Dataset')
plt.xlabel('Sepal Length in cm')
plt.ylabel('Sepal Width in cm')
plt.legend()
plt.show()
The final sampled data points in a scatter plot are as follows:
I hope you now understand what stratified sampling is and how to implement it in Python.
Happy Learning!