How to Plot Correlation Matrix in Python

A dataset usually contains many variables. Some of them depend on one another, while others may be independent. To build a better model, we must understand how the variables in a dataset are related to one another. A correlation matrix helps us learn about these relationships. In this article, we will learn how to calculate and plot a correlation matrix using Python.

A correlation can be positive, negative, or neutral (zero).

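  • Positive correlation: Both variables move in the same direction. As one increases, the other tends to increase.
  • Negative correlation: The variables move in opposite directions. As one increases, the other tends to decrease.
  • Neutral (zero) correlation: There is no linear relationship between the two variables.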
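To make the three cases concrete, here is a small sketch using made-up data (the column names and values are hypothetical and are not taken from the demo dataset). One column rises with x, one falls with x, and one is pure random noise, so their correlations with x should come out close to +1, close to -1, and close to 0.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data, used only to illustrate the three kinds of correlation
rng = np.random.default_rng(0)
x = np.arange(100, dtype=float)

df = pd.DataFrame({
    "x": x,
    "rises_with_x": 2 * x + rng.normal(0, 5, size=100),   # same direction: positive
    "falls_with_x": -3 * x + rng.normal(0, 5, size=100),  # opposite direction: negative
    "unrelated": rng.normal(0, 1, size=100),              # no relationship: near zero
})

# Correlation of every column with x (rounded for readability)
print(df.corr()["x"].round(2))
```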

The dataset used for the demo can be downloaded from here.

Correlation Matrix in Python

We will use the Seaborn module to plot the correlation matrix. To calculate the correlation, we use the built-in corr() method of a pandas DataFrame, which computes the pairwise correlation between its columns.
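As a quick illustration of corr(), the sketch below builds a tiny hypothetical DataFrame and compares the default method ("pearson", which measures linear relationships) with method="spearman" (which works on ranks and captures any monotonic relationship). Column c always increases with a but not linearly, so its Spearman correlation with a is 1 while its Pearson correlation is lower.

```python
import pandas as pd

# Hypothetical DataFrame used only for illustration
df = pd.DataFrame({
    "a": [1, 2, 3, 4, 5],
    "b": [2, 4, 6, 8, 10],    # exactly linear in a
    "c": [1, 4, 9, 16, 100],  # always increases with a, but not linearly
})

pearson = df.corr()                    # default method="pearson"
spearman = df.corr(method="spearman")  # rank-based correlation

print(pearson.round(3))
print(spearman.round(3))
```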

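Step 1: Import the required modules

import numpy as np
import pandas as pd              # used to read the CSV file
import matplotlib.pyplot as plt  # used to display the figure
import seaborn as sns            # used to draw the heatmap

sns.set()  # apply Seaborn's default plot style

# Jupyter/Colab magic: render plots inside the notebook (omit in a plain .py script)
%matplotlib inline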

Step 2: Import the data

  • Use the read_csv() method to read the CSV file.
  • Use the head() method to display the first n rows of the dataset (five by default).
train_data = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/Dataset/mobile_price.csv')
train_data.head()
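If you want to try these calls before downloading the dataset, the sketch below feeds read_csv() an in-memory CSV through io.StringIO instead of a file path. The column names match the ones used later in this article, but the values are made up for illustration and are not rows from mobile_price.csv.

```python
import io
import pandas as pd

# Hypothetical CSV text; values are invented and only stand in for mobile_price.csv
csv_text = """battery_power,dual_sim,four_g,touch_screen,price_range,ram
1000,0,1,0,1,2000
1500,1,1,1,3,3500
800,1,0,1,0,1000
1200,0,0,1,2,2800
"""

df = pd.read_csv(io.StringIO(csv_text))  # read_csv also accepts file-like objects
print(df.head(2))  # first 2 rows; head() with no argument returns 5
print(df.shape)    # (rows, columns)
```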

Output

(Figure: the first five rows of the mobile_price dataset)

Step 3: Select the columns

The dataset contains many columns, but we will select only a few of them.

Note: You can also try this with all the columns of the dataset.

columns_show = ['battery_power', 'dual_sim', 'four_g', 'touch_screen', 'price_range', 'ram']
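The sketch below shows one way to make the column selection a little safer, and how to follow the note about using all columns. It assumes a small hypothetical stand-in for train_data that includes an invented text column, model_name, to show why filtering matters: in pandas 2.0 and later, corr() raises an error on non-numeric columns unless they are excluded first (select_dtypes(include="number") does this, and so does passing numeric_only=True to corr()).

```python
import pandas as pd

# Hypothetical stand-in for train_data; values and "model_name" are invented
train_data = pd.DataFrame({
    "battery_power": [1000, 1500, 800, 1200],
    "dual_sim": [0, 1, 1, 0],
    "four_g": [1, 1, 0, 0],
    "touch_screen": [0, 1, 1, 1],
    "price_range": [1, 3, 0, 2],
    "ram": [2000, 3500, 1000, 2800],
    "model_name": ["a", "b", "c", "d"],  # hypothetical non-numeric column
})

columns_show = ['battery_power', 'dual_sim', 'four_g', 'touch_screen', 'price_range', 'ram']

# Fail early with a clear message if a column name is misspelled
missing = [c for c in columns_show if c not in train_data.columns]
if missing:
    raise KeyError(f"Columns not found: {missing}")

subset = train_data[columns_show]

# To use every column, keep only the numeric ones before calling corr()
all_numeric = train_data.select_dtypes(include="number")
print(subset.shape, all_numeric.shape)
```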

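Step 4: Generate a correlation matrix

We use the DataFrame's corr() method directly to calculate the correlation between the selected columns. By default, corr() computes the Pearson correlation coefficient, so every value lies between -1 and 1.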

# train_data[columns_show] keeps only the columns listed in columns_show
corr_matrix = train_data[columns_show].corr()
corr_matrix
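To see what corr() computes for each pair of columns, the sketch below implements the Pearson coefficient by hand, r = Σ(x − x̄)(y − ȳ) / √(Σ(x − x̄)² · Σ(y − ȳ)²), and checks it against pandas. The two columns and their values are hypothetical stand-ins for ram and price_range, and the helper name pearson is ours, not part of pandas.

```python
import numpy as np
import pandas as pd

# Hypothetical data standing in for two of the selected columns
df = pd.DataFrame({
    "ram": [2000, 3500, 1000, 2800, 1800],
    "price_range": [1, 3, 0, 2, 1],
})

def pearson(x, y):
    """Pearson r: sum of co-deviations divided by the root of the product of squared deviations."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    dx = x - x.mean()
    dy = y - y.mean()
    return (dx * dy).sum() / np.sqrt((dx ** 2).sum() * (dy ** 2).sum())

manual = pearson(df["ram"], df["price_range"])
from_pandas = df.corr().loc["ram", "price_range"]
print(round(manual, 4), round(from_pandas, 4))
```

The matrix returned by corr() is symmetric and has 1.0 on its diagonal, because every column is perfectly correlated with itself.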

Step 5: Plot the correlation matrix

Seaborn's heatmap() function is used to plot the correlation matrix. Setting annot=True writes the correlation value inside each cell of the plot.

sns.heatmap(corr_matrix, annot=True)
plt.show()
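Because a correlation matrix is symmetric, half of the heatmap repeats the other half. The sketch below is one optional variation, not the article's original plot: it hides the upper triangle with a mask, fixes the colour scale to the full range -1 to 1, and uses a diverging colormap so positive and negative values are easy to tell apart. The random data standing in for corr_matrix and the output file name corr_heatmap.png are assumptions made so the example runs on its own.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs outside a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for corr_matrix
rng = np.random.default_rng(1)
data = pd.DataFrame(rng.normal(size=(50, 4)), columns=["a", "b", "c", "d"])
corr_matrix = data.corr()

# True on and above the diagonal: those cells are hidden because they
# repeat values already shown in the lower triangle
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(
    corr_matrix,
    mask=mask,
    annot=True,
    fmt=".2f",        # two decimal places in each cell
    cmap="coolwarm",  # diverging colours for negative vs positive values
    vmin=-1, vmax=1,  # colour scale covers the full range of r
    square=True,
    ax=ax,
)
ax.set_title("Correlation matrix (lower triangle)")
fig.tight_layout()
fig.savefig("corr_heatmap.png")
```

In a notebook you can drop the matplotlib.use("Agg") line and call plt.show() instead of saving the figure.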

Output
