
How to Extract the Decision Rules from scikit-learn Decision-tree?

You might already know how to build a Decision-Tree Classifier, but you may be wondering how scikit-learn actually arrives at its splits. In this article, we will cover this in a step-by-step manner. You can run the code in sequence for better understanding.

A Decision-Tree repeatedly applies a splitting criterion to divide nodes into sub-nodes until each node becomes pure with respect to the classes or targets. At each split, purity is measured with the Gini Impurity (Gini Index), which is 0 for a perfectly pure node and grows as the classes become more mixed (up to 0.5 for a binary-classification problem). When a node's Gini Index reaches 0, splitting stops there; otherwise the tree continues to split nodes with non-zero impurity. At each step, the algorithm chooses the split that minimizes the weighted Gini Index of the resulting sub-nodes. This algorithm is known as CART (Classification and Regression Trees). Entropy can be used instead of Gini Impurity as the purity measure. To extract the decision rules from the decision tree, we use the scikit-learn library. Let's see, by example, how this is done.
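The Gini Impurity described above can be sketched in a few lines. Below is a minimal illustration (the function name and sample class counts are hypothetical, not part of scikit-learn's API):

```python
import numpy as np

def gini_impurity(class_counts):
    """Gini = 1 - sum(p_i^2), where p_i is the proportion of class i."""
    counts = np.asarray(class_counts, dtype=float)
    proportions = counts / counts.sum()
    return 1.0 - np.sum(proportions ** 2)

print(gini_impurity([5, 5]))   # maximally impure binary node -> 0.5
print(gini_impurity([10, 0]))  # pure node -> 0.0
```

A pure node (all samples in one class) gives 0, and an even 50/50 binary split gives the maximum of 0.5, matching the ranges discussed above.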



Importing Libraries

Python libraries make it very easy for us to handle the data and perform typical and complex tasks with a single line of code.




import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

Importing Dataset

In this example, we will look into a binary-classification problem, where the label is 1 if an employee's salary is greater than 100k and 0 otherwise. The code to load the data is as follows:






url = 'https://raw.githubusercontent.com/\
Stitaprajna/Practise-codes/main/salaries.csv'
df = pd.read_csv(url)
df

Output:

The Employee dataset with job type, degree, and salary structure  

Next, we will build a Decision-Tree using scikit-learn, and then plot the tree using sklearn.tree.plot_tree(). The code is as follows:




# Setting the input features (avoid shadowing
# the built-in name `input`)
inputs = df.drop('salary_more_then_100k', axis=1)

# Label-encoding the categorical columns in
# the dataset, one encoder per column
le_company = LabelEncoder()
le_job = LabelEncoder()
le_degree = LabelEncoder()
inputs['company_n'] = le_company.fit_transform(inputs['company'])
inputs['jobs_n'] = le_job.fit_transform(inputs['job'])
inputs['degree_n'] = le_degree.fit_transform(inputs['degree'])

# Building the Decision-Tree
inputs_n = inputs[['company_n', 'jobs_n', 'degree_n']]
model = DecisionTreeClassifier()
model.fit(inputs_n, df.salary_more_then_100k)

# Creating the tree plot (set the figure size
# before plotting)
plt.rcParams['figure.figsize'] = [10, 10]
tree.plot_tree(model, filled=True)
plt.show()

Output:

The Binary Decision-Tree created showing the tree splitting criteria and Gini Index for the Employee Dataset

Features of the model

Now, let's try to understand this diagram and extract the decision rules that sklearn used for splitting.
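The decision rules can also be extracted directly as text with sklearn.tree.export_text(), without reading them off the plot. The sketch below trains a tree on a small hypothetical dataset (the encoded values and labels are made up for illustration; the feature names mirror the encoded columns used above):

```python
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical label-encoded features: [company_n, jobs_n, degree_n]
X = [[0, 1, 0], [0, 2, 1], [1, 0, 0], [1, 1, 1], [2, 2, 0], [2, 0, 1]]
# Hypothetical binary labels (salary > 100k or not)
y = [0, 1, 0, 1, 0, 1]

model = DecisionTreeClassifier(random_state=0)
model.fit(X, y)

# Print the learned decision rules as an indented text tree
rules = export_text(model, feature_names=['company_n', 'jobs_n', 'degree_n'])
print(rules)
```

Each line of the output shows a threshold test (e.g. "degree_n <= 0.50") at increasing indentation, with the predicted class at each leaf, which is exactly the set of decision rules the fitted tree applies.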

