
How to Extract the Decision Rules from scikit-learn Decision-tree?


You might already know how to build a Decision-Tree Classifier, but you may be wondering how scikit-learn actually constructs one and how to read the rules it learns. In this article, we cover this step by step. You can run the code in sequence for better understanding.

A Decision Tree repeatedly splits nodes into sub-nodes until each leaf is pure with respect to the classes or targets. At each split, purity is measured with the Gini Impurity (Gini Index), which ranges between 0 and 1. A Gini Index of 0 means the node is pure and the tree stops growing there; otherwise the tree continues to split nodes with non-zero values. The splitting criterion at each node is chosen so that the Gini Index of the resulting sub-nodes is as small as possible. This algorithm is known as CART (Classification and Regression Trees). The same procedure can also be driven by Entropy instead of Gini Impurity. To extract the decision rules from the decision tree we use the scikit-learn library. Let's see with an example how this is done.
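As a quick illustration of the metric, the Gini Impurity of a node is 1 minus the sum of the squared class proportions. The helper below is a minimal sketch (the function name and sample labels are illustrative, not part of scikit-learn):

Python3

import numpy as np

def gini_impurity(labels):
    # Gini = 1 - sum(p_i^2), where p_i is the proportion of class i
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

print(gini_impurity([1, 1, 1, 1]))  # 0.0 -> pure node, splitting stops
print(gini_impurity([0, 0, 1, 1]))  # 0.5 -> maximally impure binary split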

Importing Libraries

Python libraries make it very easy for us to handle the data and perform typical and complex tasks with a single line of code.

  • Pandas – This library helps load the data frame in a 2D array format and has multiple functions to perform analysis tasks in one go.
  • Numpy – Numpy arrays are very fast and can perform large computations in very little time.
  • Matplotlib – This library is used to draw visualizations.
  • Sklearn – This module contains multiple libraries with pre-implemented functions for tasks from data preprocessing to model development and evaluation.

Python3
# Data handling
import pandas as pd
import numpy as np

# Plotting
import matplotlib.pyplot as plt

# Encoding categorical columns and building the tree
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree


Importing Dataset

In this example, we will look at a binary-classification problem, where the label is 1 if an employee's salary is greater than 100k and 0 if it is less. The code to load the data is as follows:

Python3
# Load the employee dataset from the repository
url = 'https://raw.githubusercontent.com/\
Stitaprajna/Practise-codes/main/salaries.csv'
df = pd.read_csv(url)
df


Output:

The Employee dataset with job type, degree, and salary structure

Next, we will build a Decision Tree using scikit-learn and then plot it using sklearn.tree.plot_tree(). The code is as follows:

Python3
# Setting the input features (renamed to avoid
# shadowing the built-in `input`)
inputs = df.drop('salary_more_then_100k', axis=1)

# Label-encoding the categorical columns in the dataset
le = LabelEncoder()
inputs['company_n'] = le.fit_transform(inputs['company'])
inputs['jobs_n'] = le.fit_transform(inputs['job'])
inputs['degree_n'] = le.fit_transform(inputs['degree'])

# Building the Decision Tree
inputs_n = inputs[['company_n', 'jobs_n', 'degree_n']]
model = DecisionTreeClassifier()
model.fit(inputs_n, df.salary_more_then_100k)

# Creating the tree plot (set the figure size before
# plotting so it takes effect on this figure)
plt.rcParams['figure.figsize'] = [10, 10]
tree.plot_tree(model, filled=True)
plt.show()


Output:

The Binary Decision Tree created, showing the tree splitting criteria and Gini Index for the Employee dataset

Features of the model

Now, let's try to understand this diagram and extract the decision rules sklearn used for splitting.

  • If we begin from the Root Node, the topmost light blue box, you can see that the splitting criterion is X[0] <= 0.5, with a Gini Index of 0.469. The node's samples are split into two groups based on whether they satisfy this condition: samples for which X[0] <= 0.5 holds go to the left branch, and the others go to the right.
  • Now we have two nodes, and splitting continues based on the Gini Index. If a node's Gini Index is 0.0, the grouped data is considered pure and no further splitting occurs.
  • You can see the ‘deep orange’ and ‘deep blue’ boxes with a zero Gini Index; these are the pure classes (0 or 1). A box is colored ‘blue’ or ‘orange’ depending on whether its majority class (salary_more_then_100k) is 0 or 1. The intensity of the color reflects the node's purity; if a box is ‘white’, the grouped data contains equal numbers of the 0 and 1 classes. Following the conditions along a path from the root to a leaf gives one complete decision rule; scikit-learn can also print these rules directly, as shown below.
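Beyond reading the plot, scikit-learn can export these same rules as plain text with sklearn.tree.export_text(). A minimal sketch, assuming the model and the encoded feature columns defined above:

Python3

from sklearn.tree import export_text

# Print the learned decision rules as an indented text tree
rules = export_text(model, feature_names=['company_n', 'jobs_n', 'degree_n'])
print(rules)

Each indented line of the output is one split condition, so the printed rules mirror the paths you would trace from the root to a leaf in the plotted tree.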

