
Text Classification using Decision Trees in Python

Last Updated : 18 Mar, 2024

Text classification is the task of assigning text documents to predefined categories. In this article, we explore how decision trees can be used to classify textual data.

Text Classification and Decision Trees

Text classification involves assigning predefined categories or labels to text documents based on their content. Decision trees are hierarchical tree structures that recursively partition the feature space based on the values of input features. They are particularly well-suited for classification tasks due to their simplicity, interpretability, and ability to handle non-linear relationships.

Decision Trees provide a clear and understandable model for text classification, making them an excellent choice for tasks where interpretability is as important as predictive power. Their simplicity has a cost, however: a single, fully grown tree tends to overfit the high-dimensional, sparse features produced from text. For this reason, practitioners often turn to ensemble methods or other more sophisticated models to improve accuracy.

Implementation: Text Classification using Decision Trees

For text classification using Decision Trees in Python, we’ll use the popular 20 Newsgroups dataset, which contains around 20,000 newsgroup posts spread across 20 newsgroups. We’ll use scikit-learn to fetch the dataset and convert the text into TF-IDF feature vectors. We then train a Decision Tree classifier on those vectors.

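Ensure you have scikit-learn installed in your environment. The plots below also need Matplotlib, and NumPy is installed automatically as a scikit-learn dependency. You can install both packages with pip if you haven’t already:

pip install scikit-learn matplotlib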

Importing Necessary Libraries

Python3
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
import numpy as np


Load the Dataset

To keep the example small, only four of the 20 categories are loaded. Headers, footers, and quoted replies are removed so that the model focuses on the text content of each post.

Python3
# Load the dataset
categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']
newsgroups_train = fetch_20newsgroups(subset='train', categories=categories, remove=('headers', 'footers', 'quotes'))
newsgroups_test = fetch_20newsgroups(subset='test', categories=categories, remove=('headers', 'footers', 'quotes'))


Exploratory Data Analysis

These snippets provide basic exploratory data analysis by plotting the distribution of classes in the training and test sets.

Python3
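# Display distribution of classes in the training set
# (labels are read from the loaded dataset, because y_train is only defined later, during preprocessing)
class_distribution = np.bincount(newsgroups_train.target)
plt.bar(range(len(class_distribution)), class_distribution)
plt.xticks(range(len(class_distribution)), newsgroups_train.target_names, rotation=45)
plt.title('Distribution of Classes in Training Set')
plt.xlabel('Class')
plt.ylabel('Number of Documents')
plt.tight_layout()  # keep the rotated class names inside the figure
plt.show()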

Output:

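[Figure: Distribution of Classes in Training Set (bar chart; x-axis: Class, y-axis: Number of Documents)]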

Python3
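# Display distribution of classes in the test set
# (labels are read from the loaded dataset, because y_test is only defined later, during preprocessing)
class_distribution = np.bincount(newsgroups_test.target)
plt.bar(range(len(class_distribution)), class_distribution)
plt.xticks(range(len(class_distribution)), newsgroups_test.target_names, rotation=45)
plt.title('Distribution of Classes in Test Set')
plt.xlabel('Class')
plt.ylabel('Number of Documents')
plt.tight_layout()  # keep the rotated class names inside the figure
plt.show()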

Output:

[Figure: Distribution of Classes in Test Set (bar chart; x-axis: Class, y-axis: Number of Documents)]
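The plots show class balance only. A text summary that also previews one document per class can help you spot problems early. For example, with headers, footers, and quotes removed, some posts become very short or even empty.

Below is a minimal sketch of such a summary. summarize_split is a hypothetical helper written for this article, not a scikit-learn function, and the toy call uses made-up documents so the block runs on its own.

```python
import numpy as np

def summarize_split(texts, targets, target_names, preview_chars=80):
    """Print per-class document counts plus a short preview of one document per class.

    Hypothetical helper for this article; returns the per-class counts as a NumPy array.
    """
    targets = np.asarray(targets)
    counts = np.bincount(targets, minlength=len(target_names))
    total = counts.sum()
    for label, name in enumerate(target_names):
        print(f"{name:<25}{counts[label]:>6} ({counts[label] / total:.1%})")
        # Indices of documents with this label; preview the first one on a single line
        idx = np.flatnonzero(targets == label)
        if idx.size:
            snippet = " ".join(texts[idx[0]].split())[:preview_chars]
            print(f"    e.g. {snippet!r}")
    return counts

# Toy call with made-up documents so the sketch runs on its own
summarize_split(["Doctors prescribe\nmedicine.", "GPUs render images.", "Shaders run on GPUs."],
                [0, 1, 1], ["sci.med", "comp.graphics"])

# With the dataset loaded earlier:
# summarize_split(newsgroups_train.data, newsgroups_train.target, newsgroups_train.target_names)
```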

Data Preprocessing

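Text data is converted into TF-IDF feature vectors. TF-IDF (Term Frequency-Inverse Document Frequency) is a numerical statistic that reflects how important a word is to a document in a collection. A term’s weight grows with how often it appears in the document and shrinks as more documents in the collection contain it. Passing stop_words='english' also drops common English words such as ‘the’ and ‘and’. The vectorizer is fitted on the training documents only and then reused to transform the test documents, so the vocabulary and IDF weights come from the training set alone.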

Python3
# Data preprocessing
vectorizer = TfidfVectorizer(stop_words='english')
X_train = vectorizer.fit_transform(newsgroups_train.data)
X_test = vectorizer.transform(newsgroups_test.data)
y_train = newsgroups_train.target
y_test = newsgroups_test.target


Decision Tree Classifier

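A Decision Tree classifier is initialized and trained on the processed training data. Decision Trees are a non-linear predictive modeling tool that can be used for both classification and regression tasks. With the default settings used here, the tree keeps splitting until its leaves are pure or cannot be split further. Setting random_state=42 makes the result reproducible.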

Python3
# Initialize and train a Decision Tree classifier
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)


Model Evaluation

The trained model is used to make predictions on the test set, and the model’s performance is evaluated using accuracy and a detailed classification report, which includes precision, recall, f1-score, and support for each class.

Python3
# Make predictions
y_pred = clf.predict(X_test)
# Evaluate the model
print("Accuracy:", accuracy_score(y_test, y_pred))
print("\nClassification Report:\n", classification_report(y_test, y_pred, target_names=newsgroups_test.target_names))

Output:

Accuracy: 0.6324900133155792

Classification Report:
                        precision    recall  f1-score   support

           alt.atheism       0.53      0.39      0.45       319
         comp.graphics       0.61      0.82      0.70       389
               sci.med       0.66      0.58      0.62       396
soc.religion.christian       0.70      0.69      0.70       398

              accuracy                           0.63      1502
             macro avg       0.62      0.62      0.62      1502
          weighted avg       0.63      0.63      0.62      1502

The output shows how the Decision Tree classifier performs on the four-category subset of the 20 Newsgroups test set. An accuracy of approximately 63.25% means the model assigns the correct newsgroup to roughly two out of every three test posts. The per-class scores show how well the model handles each category. Precision is the fraction of posts predicted as a class that actually belong to it, and recall is the fraction of posts in a class that the model correctly identifies. The F1-score is the harmonic mean of precision and recall. Performance varies with the subject matter: comp.graphics and soc.religion.christian reach the highest F1-score (0.70 each), while alt.atheism is the weakest category (F1-score 0.45, recall 0.39).
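The three per-class metrics can be computed directly from true-positive (TP), false-positive (FP), and false-negative (FN) counts. The sketch below does this on made-up labels for a 3-class problem and compares the results with scikit-learn's precision_recall_fscore_support, which returns the same per-class values that classification_report prints. per_class_scores is a hypothetical helper written for illustration.

```python
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# Made-up labels for a 3-class problem (10 test documents)
y_true = np.array([0, 0, 0, 1, 1, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 1, 1, 0, 2, 2, 1])

def per_class_scores(y_true, y_pred, label):
    """Precision, recall and F1 for one class, computed from TP/FP/FN counts."""
    tp = np.sum((y_pred == label) & (y_true == label))  # predicted this class, and correct
    fp = np.sum((y_pred == label) & (y_true != label))  # predicted this class, but wrong
    fn = np.sum((y_pred != label) & (y_true == label))  # documents of this class that were missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# scikit-learn's per-class values, for comparison
p, r, f, support = precision_recall_fscore_support(y_true, y_pred)
for label in range(3):
    print(label, np.round(per_class_scores(y_true, y_pred, label), 3),
          np.round((p[label], r[label], f[label]), 3))
```

The macro avg row in the report is the unweighted mean of these per-class values. The weighted avg row weights each class by its support.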

Comparison with Other Text Classification Techniques

We will compare decision trees with other popular text classification algorithms such as Random Forest and Support Vector Machines.

Text Classification using Random Forest

Python3
from sklearn.ensemble import RandomForestClassifier

# Initialize and train a Random Forest classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Make predictions
y_pred = clf.predict(X_test)

# Evaluate the model
print("\nClassification Report:\n", classification_report(y_test, y_pred, target_names=newsgroups_test.target_names))

Output:

Classification Report:
                        precision    recall  f1-score   support

           alt.atheism       0.70      0.48      0.57       319
         comp.graphics       0.77      0.93      0.84       389
               sci.med       0.80      0.75      0.77       396
soc.religion.christian       0.74      0.82      0.78       398

              accuracy                           0.76      1502
             macro avg       0.75      0.74      0.74      1502
          weighted avg       0.75      0.76      0.75      1502
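A Random Forest combines many decision trees, each trained on a bootstrap sample of the documents. In scikit-learn, the forest's predict_proba is the mean of its trees' predicted class probabilities, and predict returns the class with the highest mean probability rather than a hard majority vote. The sketch below checks this on a made-up three-class corpus. The small forest size (n_estimators=25) is an arbitrary choice for speed.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

# Made-up three-class corpus
docs = [
    "doctor prescribed medicine", "patient saw the doctor",
    "render graphics image", "graphics card renders images",
    "church prayer faith", "faith and god",
]
labels = np.array([0, 0, 1, 1, 2, 2])
X = TfidfVectorizer(stop_words="english").fit_transform(docs)

rf = RandomForestClassifier(n_estimators=25, random_state=42).fit(X, labels)

# Average the class probabilities of the individual fitted trees by hand
per_tree = np.stack([tree.predict_proba(X) for tree in rf.estimators_])
manual_proba = per_tree.mean(axis=0)
print("matches predict_proba:", np.allclose(manual_proba, rf.predict_proba(X)))

# predict() returns the class with the highest averaged probability
print("predictions:", rf.predict(X), "number of trees:", len(rf.estimators_))
```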


Text Classification using SVM

Python3
from sklearn.svm import SVC

# Initialize and train an SVM classifier
clf = SVC(kernel='linear', random_state=42)
clf.fit(X_train, y_train)

# Make predictions
y_pred = clf.predict(X_test)

# Evaluate the model
print("\nClassification Report:\n", classification_report(y_test, y_pred, target_names=newsgroups_test.target_names))

Output:

Classification Report:
                        precision    recall  f1-score   support

           alt.atheism       0.75      0.63      0.68       319
         comp.graphics       0.91      0.90      0.90       389
               sci.med       0.80      0.90      0.85       396
soc.religion.christian       0.80      0.82      0.81       398

              accuracy                           0.82      1502
             macro avg       0.82      0.81      0.81      1502
          weighted avg       0.82      0.82      0.82      1502

Observations

  1. SVM performs best on every summary metric, with 0.82 accuracy and a 0.81 macro-averaged F1-score. Random Forest reaches 0.76 and 0.74, and the Decision Tree reaches 0.63 and 0.62.
  2. Random Forest, an ensemble of 100 decision trees, improves on the single tree by about 13 points of accuracy but still trails SVM by about 6 points.
  3. The single Decision Tree performs worst of the three classifiers. This underlines the importance of choosing an algorithm suited to high-dimensional, sparse TF-IDF features.
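The three experiments above repeat the same fit, predict, and evaluate steps. Below is a minimal sketch of running them in one loop. It uses scikit-learn's make_pipeline so that TF-IDF vectorization and each classifier are fitted together on the training texts only. compare_models is a hypothetical helper name. The model settings mirror the ones used in this article, and the toy data in the test is made up.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score, f1_score
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

def compare_models(train_texts, y_train, test_texts, y_test):
    """Hypothetical helper: fit TF-IDF + each classifier, return {name: (accuracy, macro F1)}."""
    models = {
        "Decision Tree": DecisionTreeClassifier(random_state=42),
        "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
        "SVM (linear)": SVC(kernel="linear", random_state=42),
    }
    results = {}
    for name, model in models.items():
        # Each pipeline gets its own vectorizer, fitted on the training texts only
        pipe = make_pipeline(TfidfVectorizer(stop_words="english"), model)
        pipe.fit(train_texts, y_train)
        y_pred = pipe.predict(test_texts)
        results[name] = (accuracy_score(y_test, y_pred),
                         f1_score(y_test, y_pred, average="macro"))
    return results

# Usage with the data loaded earlier:
# for name, (acc, f1) in compare_models(newsgroups_train.data, newsgroups_train.target,
#                                       newsgroups_test.data, newsgroups_test.target).items():
#     print(f"{name:<15} accuracy={acc:.3f} macro-F1={f1:.3f}")
```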

