Kullback-Leibler Divergence in Python – Machine Learning
In this tutorial, we will dive into the Kullback-Leibler Divergence (KL Divergence) method, learn the mathematics behind it, and apply the concepts using Python.
Kullback-Leibler Divergence Method
You have studied probability distributions in your statistics course. When a random variable is plotted using its mean and variance parameters, we get a probability distribution curve; changing these parameters gives a different curve.
We need methods to compare probability distribution functions (PDFs). Such comparisons are needed to draw conclusions in various fields such as data science, quality control, and more. Kullback-Leibler Divergence is one such method for comparing PDFs.
Mathematics behind KL Divergence
Consider two probability distribution functions, P and Q. The KL Divergence compares the two distributions and quantifies the extra information needed when data that actually comes from P is encoded using a code built for Q. Because the result depends on which distribution is used as the reference, the KL Divergence is asymmetric: D(P || Q) is generally not equal to D(Q || P).
Let’s see the formula to understand this better:

D_KL(P || Q) = Σ P(x) * log( P(x) / Q(x) ), where the sum runs over all outcomes x.

Here, P(x) and Q(x) are the probabilities that the two distributions assign to the outcome x, and the base of the logarithm sets the unit (natural log gives nats, base 2 gives bits). From the formula, we can conclude that the KL Divergence of P from Q is zero only if P and Q are identical.
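Both properties, asymmetry and zero divergence only at equality, are easy to verify on a small discrete example. The sketch below is a minimal illustration and not part of the original tutorial; the two three-outcome distributions are made up purely for demonstration.

import numpy as np

def kl_divergence(p, q):
    # KL Divergence D(p || q) for discrete distributions given as arrays
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return np.sum(p * np.log(p / q))

# Two made-up three-outcome distributions (assumed values, for illustration only)
p = np.array([0.6, 0.3, 0.1])
q = np.array([0.2, 0.5, 0.3])

print(kl_divergence(p, q))  # D(P || Q)
print(kl_divergence(q, p))  # D(Q || P) differs from D(P || Q), showing asymmetry
print(kl_divergence(p, p))  # 0.0, because the two distributions are identical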
Python Code: Kullback-Leibler Divergence
Let’s code in Python. First, we will generate two distributions using different means and standard deviations. Then, we will apply the formula to compute the KL Divergence. We will also plot the PDFs for a better understanding.
import numpy as np

# Define the mean and standard deviation for each distribution
mean1 = 0.5
std_dev1 = 0.2
mean2 = 0.8
std_dev2 = 0.3

# Generate 30 random samples from a normal distribution with each mean and standard deviation
p = np.random.normal(mean1, std_dev1, 30)
q = np.random.normal(mean2, std_dev2, 30)

# Ensure all values are positive
p = np.abs(p)
q = np.abs(q)

# Normalize to form a valid probability distribution
p /= np.sum(p)
q /= np.sum(q)

# Now, calculate the KL Divergence
kl_divergence = np.sum(p * np.log(p / q))
print("KL Divergence:", kl_divergence)
KL Divergence: 0.21726111089437736
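Because no random seed is set, the samples (and therefore the printed value) will differ on every run. As a cross-check, the same quantity can be computed with SciPy, assuming SciPy is installed: scipy.stats.entropy, when given two distributions, returns the KL Divergence of the first from the second. The seed below is an assumption added only to make the check reproducible.

import numpy as np
from scipy.stats import entropy

np.random.seed(0)  # assumed seed, only so the two results can be compared run to run
p = np.abs(np.random.normal(0.5, 0.2, 30))
q = np.abs(np.random.normal(0.8, 0.3, 30))
p /= p.sum()
q /= q.sum()

manual = np.sum(p * np.log(p / q))  # formula applied directly
scipy_kl = entropy(p, q)            # D(p || q) computed by SciPy
print(manual, scipy_kl)             # the two values should match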
import seaborn as sns
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(16, 10))

# Bar plots of the two distributions
sns.barplot(x=np.arange(len(p)), y=p, label='p (Bar)', color='blue', alpha=0.7, ax=ax)
sns.barplot(x=np.arange(len(q)), y=q, label='q (Bar)', color='red', alpha=0.7, ax=ax)

# Line plots of the same values
sns.lineplot(x=np.arange(len(p)), y=p, label='p (Line)', color='darkgreen', ax=ax, linewidth=2)
sns.lineplot(x=np.arange(len(q)), y=q, label='q (Line)', color='black', ax=ax, linewidth=2)

plt.title('Probability Density Functions')
plt.xlabel('Index')
plt.ylabel('Value')
plt.legend()
plt.show()
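If you would rather see the smooth bell curves than the 30 normalized samples, you can also plot the underlying normal densities directly. This is an optional sketch, assuming SciPy is available; it reuses the same means and standard deviations defined earlier, and the x-axis range is chosen arbitrarily to cover both curves.

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

# Grid of x values covering both distributions
x = np.linspace(-0.5, 2.0, 300)

# Density curves of the two normal distributions defined earlier
plt.figure(figsize=(10, 6))
plt.plot(x, norm.pdf(x, loc=0.5, scale=0.2), label='P ~ N(0.5, 0.2)')
plt.plot(x, norm.pdf(x, loc=0.8, scale=0.3), label='Q ~ N(0.8, 0.3)')
plt.title('Underlying Normal Density Curves')
plt.xlabel('x')
plt.ylabel('Density')
plt.legend()
plt.show()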