
Pair plots using Scatter matrix in Pandas


Checking for collinearity among the attributes of a dataset is one of the most important steps in data preprocessing. A good way to understand how the features relate to one another is to create a scatter plot for each pair of attributes. Pandas provides the function scatter_matrix() for this purpose. It generates a grid of scatter plots covering every pair of numerical features: each numerical feature is plotted against every other numerical feature, and the main diagonal shows a histogram of each feature.

Syntax : pandas.plotting.scatter_matrix(frame)
Parameters :
frame : the dataframe to be plotted.
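
Beyond frame, scatter_matrix() accepts optional keyword arguments such as alpha, figsize and diagonal. The sketch below illustrates the call on a small synthetic DataFrame; the column names a, b, c and label are made up for illustration and are not part of the housing dataset. It assumes a non-interactive Matplotlib backend so that it runs without a display; drop that line to see the figure in a window.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; remove to view the plot interactively

import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix

# Synthetic stand-in data (hypothetical columns, not the housing dataset)
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "a": rng.normal(size=200),
    "b": rng.normal(size=200),
    "c": rng.normal(size=200),
    "label": ["x"] * 200,  # non-numeric column, left out of the plot
})

# alpha sets point transparency, figsize the figure size in inches,
# and diagonal="hist" draws histograms on the main diagonal
axes = scatter_matrix(df, alpha=0.5, figsize=(6, 6), diagonal="hist")

# The return value is a 2-D array of Matplotlib Axes, one per pair of numeric columns
print(axes.shape)
```

Only the numeric columns are plotted, so the grid here should have one row and one column for each of a, b and c.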

The dataset used here, housing.csv, contains prices and other statistics about houses in California districts.




import pandas as pd
  
# loading the dataset
data = pd.read_csv('housing.csv')
  
# inspecting the data
data.info()


Output :

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20640 entries, 0 to 20639
Data columns (total 10 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   longitude           20640 non-null  float64
 1   latitude            20640 non-null  float64
 2   housing_median_age  20640 non-null  float64
 3   total_rooms         20640 non-null  float64
 4   total_bedrooms      20433 non-null  float64
 5   population          20640 non-null  float64
 6   households          20640 non-null  float64
 7   median_income       20640 non-null  float64
 8   median_house_value  20640 non-null  float64
 9   ocean_proximity     20640 non-null  object 
dtypes: float64(9), object(1)
memory usage: 1.6+ MB

Creating the scatter plots

Let us select three numeric columns for plotting: median_house_value, housing_median_age and median_income. Note that Pandas plotting is built on Matplotlib, so Matplotlib is imported as well; plt.show() is what displays the figure.




import matplotlib.pyplot as plt
from pandas.plotting import scatter_matrix
  
# selecting three numerical features
features = ['median_house_value', 'housing_median_age',
            'median_income']
   
# plotting the scatter matrix
# with the features
scatter_matrix(data[features])
plt.show()


Output :

Each scatter plot in the matrix shows the relationship between the corresponding pair of attributes. As we can see, median_income and median_house_value are quite strongly positively correlated. The main diagonal contains the histogram of each attribute.
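
A visual impression of correlation can be backed up with numbers using DataFrame.corr(), which returns the pairwise Pearson correlation coefficients. The following is a minimal sketch on synthetic stand-in data, not the actual housing.csv values: the three columns borrow the article's names, but the numbers are generated so that house value depends on income while age is independent. With the real dataset loaded as data, the equivalent call would be data[features].corr().

```python
import numpy as np
import pandas as pd

# Synthetic stand-in data (assumption: NOT the real housing.csv values)
rng = np.random.default_rng(42)
income = rng.normal(loc=4.0, scale=1.5, size=500)
value = 50000 * income + rng.normal(scale=40000, size=500)  # depends on income plus noise
age = rng.uniform(1, 52, size=500)                          # independent of the others

demo = pd.DataFrame({
    "median_house_value": value,
    "housing_median_age": age,
    "median_income": income,
})

# Pairwise Pearson correlation coefficients, one row and column per feature
corr = demo.corr()
print(corr.round(2))
```

Values close to 1 or -1 indicate a strong linear relationship, and values near 0 indicate a weak one. The matrix is symmetric, with 1.0 on the diagonal, mirroring the layout of the scatter matrix.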



Last Updated : 21 Mar, 2024