Pair plots using Scatter matrix in Pandas

Checking for collinearity among the attributes of a dataset is an important step in data preprocessing. A good way to understand the correlation among the features is to create a scatter plot for each pair of attributes. Pandas provides the function scatter_matrix() for this purpose. It generates a grid of scatter plots between all pairs of numerical features: each numerical feature is plotted against every other numerical feature, and a histogram of each feature is drawn on the main diagonal.

Syntax : pandas.plotting.scatter_matrix(frame)
Parameters :
frame : the dataframe to be plotted.
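
Besides frame, scatter_matrix() accepts several optional keyword arguments. The following is a minimal sketch of a few of them (alpha, figsize, diagonal and hist_kwds). It assumes a small synthetic DataFrame with hypothetical columns a, b and c, and it selects the non-interactive Agg backend only so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove this line to view plots on screen

import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix

# Synthetic stand-in data; the column names a, b, c are hypothetical
rng = np.random.default_rng(0)
demo = pd.DataFrame(rng.normal(size=(200, 3)), columns=["a", "b", "c"])

# scatter_matrix returns a 2-D NumPy array of Matplotlib Axes,
# one per (row feature, column feature) pair
axes = scatter_matrix(
    demo,
    alpha=0.5,               # point transparency
    figsize=(8, 8),          # figure size in inches
    diagonal="hist",         # histograms on the main diagonal
    hist_kwds={"bins": 20},  # extra arguments passed to the histograms
)
print(axes.shape)
```

Because the return value is an array of Axes, individual panels can be customised afterwards, for example by changing their axis labels.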

In the example below, we will create scatter plots for the California housing dataset, loaded from the file housing.csv.

The dataset contains house prices and other statistics for districts in California.

import pandas as pd
  
# loading the dataset
data = pd.read_csv('housing.csv')
  
# inspecting the data
data.info()

Output :

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20640 entries, 0 to 20639
Data columns (total 10 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   longitude           20640 non-null  float64
 1   latitude            20640 non-null  float64
 2   housing_median_age  20640 non-null  float64
 3   total_rooms         20640 non-null  float64
 4   total_bedrooms      20433 non-null  float64
 5   population          20640 non-null  float64
 6   households          20640 non-null  float64
 7   median_income       20640 non-null  float64
 8   median_house_value  20640 non-null  float64
 9   ocean_proximity     20640 non-null  object 
dtypes: float64(9), object(1)
memory usage: 1.6+ MB
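
The summary above shows one non-numeric column (ocean_proximity) and one column with missing values (total_bedrooms). As a rough sketch of how to check this programmatically before plotting, the snippet below uses select_dtypes() and isna() on a tiny hand-made frame. The values are made up for illustration and are not rows from housing.csv.

```python
import numpy as np
import pandas as pd

# Tiny illustrative frame; the values are invented, not taken from the real dataset
sample = pd.DataFrame({
    "total_bedrooms": [129.0, np.nan, 190.0, 235.0],
    "median_house_value": [452600.0, 358500.0, 352100.0, 341300.0],
    "ocean_proximity": ["NEAR BAY", "NEAR BAY", "INLAND", "INLAND"],
})

# Keep only the numeric columns, which are the ones a scatter matrix can plot
numeric_cols = sample.select_dtypes(include="number").columns.tolist()

# Count the missing values in each numeric column
missing = sample[numeric_cols].isna().sum()

print(numeric_cols)
print(missing)
```

The same two calls can be applied to the data frame loaded earlier to decide which columns to pass to scatter_matrix().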

Creating the scatter plots

Let us select three numeric columns for plotting: median_house_value, housing_median_age and median_income. Pandas plotting is built on Matplotlib, so Matplotlib must be installed; we import matplotlib.pyplot in order to display the figure.

import matplotlib.pyplot as plt
from pandas.plotting import scatter_matrix
  
# selecting three numerical features
features = ['median_house_value', 'housing_median_age',
            'median_income']
   
# plotting the scatter matrix
# with the features
scatter_matrix(data[features])
plt.show()

Output :

Each scatter plot in the matrix helps us understand the correlation between the corresponding pair of attributes. As we can see, median_income and median_house_value show a clear positive correlation: districts with a higher median income tend to have a higher median house value. The main diagonal contains the histogram of each attribute.
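
A visual impression from the scatter matrix can be backed up with numbers. One possible way, sketched below, is DataFrame.corr(), which computes pairwise Pearson correlation coefficients by default. The sketch uses synthetic data that only borrows the column names from the example; the positive relationship between income and house value is built in by construction, so the resulting numbers say nothing about the real dataset.

```python
import numpy as np
import pandas as pd

# Synthetic data: house value is generated as a noisy linear function of income
rng = np.random.default_rng(42)
income = rng.uniform(1, 10, size=500)
toy = pd.DataFrame({
    "median_income": income,
    "median_house_value": 40000 * income + rng.normal(0, 50000, size=500),
    "housing_median_age": rng.uniform(1, 52, size=500),
})

# Pairwise Pearson correlation between all numeric columns
corr = toy.corr()
print(corr.round(2))
```

On the real data, the same call would be data[features].corr(), using the data and features variables defined earlier.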
