
Breast Cancer Predictions Using CatBoost

Last Updated : 08 Apr, 2024

CatBoost is a gradient boosting library that handles categorical features natively during training. In this article, we perform prediction analysis on the Breast Cancer dataset using CatBoost.

Breast Cancer Detection using CatBoost

We aim to provide a comprehensive pipeline for training a CatBoostClassifier model on the Breast Cancer dataset, evaluating its performance, and making predictions on new data points.

To install the CatBoost library, use the following command:

pip install catboost

Step 1: Importing Libraries and Loading Dataset

We import the necessary libraries: Pandas, Matplotlib, Seaborn, NumPy, Scikit-learn's load_breast_cancer() and train_test_split(), CatBoostClassifier, and the evaluation metrics accuracy, precision, recall, and F1-score.

The Breast Cancer dataset is loaded using load_breast_cancer() from Scikit-learn.

Python3
# Importing necessary libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_breast_cancer
import numpy as np
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


# Load the breast cancer dataset
data = load_breast_cancer()

# Create a DataFrame from the dataset
df = pd.DataFrame(data.data, columns=data.feature_names)
df['target'] = data.target

# Display the first few rows of the dataset
print(df.head())

Output:

   mean radius  mean texture  mean perimeter  mean area  mean smoothness  \
0        17.99         10.38          122.80     1001.0          0.11840   
1        20.57         17.77          132.90     1326.0          0.08474   
2        19.69         21.25          130.00     1203.0          0.10960   
3        11.42         20.38           77.58      386.1          0.14250   
4        20.29         14.34          135.10     1297.0          0.10030   

   mean compactness  mean concavity  mean concave points  mean symmetry  \
0           0.27760          0.3001              0.14710         0.2419   
1           0.07864          0.0869              0.07017         0.1812   
2           0.15990          0.1974              0.12790         0.2069   
3           0.28390          0.2414              0.10520         0.2597   
4           0.13280          0.1980              0.10430         0.1809   

   mean fractal dimension  ...  worst texture  worst perimeter  worst area  \
0                 0.07871  ...          17.33           184.60      2019.0   
1                 0.05667  ...          23.41           158.80      1956.0   
2                 0.05999  ...          25.53           152.50      1709.0   
3                 0.09744  ...          26.50            98.87       567.7   
4                 0.05883  ...          16.67           152.20      1575.0   

   worst smoothness  worst compactness  worst concavity  worst concave points  \
0            0.1622             0.6656           0.7119                0.2654   
1            0.1238             0.1866           0.2416                0.1860   
2            0.1444             0.4245           0.4504                0.2430   
3            0.2098             0.8663           0.6869                0.2575   
4            0.1374             0.2050           0.4000                0.1625   

   worst symmetry  worst fractal dimension  target  
0          0.4601                  0.11890       0  
1          0.2750                  0.08902       0  
2          0.3613                  0.08758       0  
3          0.6638                  0.17300       0  
4          0.2364                  0.07678       0  

[5 rows x 31 columns]
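The labels in data.target are plain integers; data.target_names gives their meaning. Checking it before interpreting any result matters here, because class 0 is malignant and class 1 is benign. Scikit-learn can also build the DataFrame (features plus a target column) directly with as_frame=True, available in scikit-learn 0.23 and later. A short sketch:

```python
from sklearn.datasets import load_breast_cancer

# as_frame=True returns pandas objects; data.frame holds features + 'target'
data = load_breast_cancer(as_frame=True)
df = data.frame

# target_names[i] is the class name for integer label i
for label, name in enumerate(data.target_names):
    print(label, name)   # 0 malignant, then 1 benign

print(df.shape)          # (569, 31): 30 features plus the target column
```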

Step 2: Exploratory Data Analysis on the Breast Cancer Dataset

The Exploratory Data Analysis (EDA) section provides a comprehensive overview of the Breast Cancer dataset:

Summary Statistics

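A summary of statistical measures for each numerical feature is presented using df.describe(): count, mean, standard deviation, minimum, the 25th, 50th (median), and 75th percentiles, and maximum. These values show the distribution and range of each feature. The features are on very different scales (for example, mean area ranges from 143.5 to 2501, while mean smoothness never exceeds 0.1634), which tree-based models such as CatBoost handle without any rescaling.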

Python3
# Summary statistics
print(df.describe())

Output:

       mean radius  mean texture  mean perimeter    mean area  \
count   569.000000    569.000000      569.000000   569.000000   
mean     14.127292     19.289649       91.969033   654.889104   
std       3.524049      4.301036       24.298981   351.914129   
min       6.981000      9.710000       43.790000   143.500000   
25%      11.700000     16.170000       75.170000   420.300000   
50%      13.370000     18.840000       86.240000   551.100000   
75%      15.780000     21.800000      104.100000   782.700000   
max      28.110000     39.280000      188.500000  2501.000000   

       mean smoothness  mean compactness  mean concavity  mean concave points  \
count       569.000000        569.000000      569.000000           569.000000   
mean          0.096360          0.104341        0.088799             0.048919   
std           0.014064          0.052813        0.079720             0.038803   
min           0.052630          0.019380        0.000000             0.000000   
25%           0.086370          0.064920        0.029560             0.020310   
50%           0.095870          0.092630        0.061540             0.033500   
75%           0.105300          0.130400        0.130700             0.074000   
max           0.163400          0.345400        0.426800             0.201200   

       mean symmetry  mean fractal dimension  ...  worst texture  \
count     569.000000              569.000000  ...     569.000000   
mean        0.181162                0.062798  ...      25.677223   
std         0.027414                0.007060  ...       6.146258   
min         0.106000                0.049960  ...      12.020000   
25%         0.161900                0.057700  ...      21.080000   
50%         0.179200                0.061540  ...      25.410000   
75%         0.195700                0.066120  ...      29.720000   
max         0.304000                0.097440  ...      49.540000   

       worst perimeter   worst area  worst smoothness  worst compactness  \
count       569.000000   569.000000        569.000000         569.000000   
mean        107.261213   880.583128          0.132369           0.254265   
std          33.602542   569.356993          0.022832           0.157336   
min          50.410000   185.200000          0.071170           0.027290   
25%          84.110000   515.300000          0.116600           0.147200   
50%          97.660000   686.500000          0.131300           0.211900   
75%         125.400000  1084.000000          0.146000           0.339100   
max         251.200000  4254.000000          0.222600           1.058000   

       worst concavity  worst concave points  worst symmetry  \
count       569.000000            569.000000      569.000000   
mean          0.272188              0.114606        0.290076   
std           0.208624              0.065732        0.061867   
min           0.000000              0.000000        0.156500   
25%           0.114500              0.064930        0.250400   
50%           0.226700              0.099930        0.282200   
75%           0.382900              0.161400        0.317900   
max           1.252000              0.291000        0.663800   

       worst fractal dimension      target  
count               569.000000  569.000000  
mean                  0.083946    0.627417  
std                   0.018061    0.483918  
min                   0.055040    0.000000  
25%                   0.071460    0.000000  
50%                   0.080040    1.000000  
75%                   0.092080    1.000000  
max                   0.207500    1.000000  

[8 rows x 31 columns]
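With 31 columns, the describe() output is easier to scan when transposed so that each feature becomes a row. It is also worth confirming that the dataset has no missing values before training. A short sketch that rebuilds df the same way as in Step 1:

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()
df = pd.DataFrame(data.data, columns=data.feature_names)
df['target'] = data.target

# One row per column of df; keep only a few of the summary statistics
summary = df.describe().T[['mean', 'std', 'min', 'max']]
print(summary.head())

# Total number of missing cells across the whole DataFrame
missing = int(df.isnull().sum().sum())
print("Missing values:", missing)
```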

Counting the Target Values

The distribution of target values is examined with df['target'].value_counts(), which reports the number of instances in each class (in this dataset, 0 = malignant and 1 = benign). This shows whether the classes are balanced or imbalanced.

Python3
# Count of target values
print(df['target'].value_counts())

Output:

target
1    357
0    212
Name: count, dtype: int64
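Expressing the same counts as proportions, with the integer labels mapped to class names through data.target_names, makes the imbalance explicit (roughly 63% benign and 37% malignant). A short sketch:

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()
# Map integer labels to names: 0 -> 'malignant', 1 -> 'benign'
labels = pd.Series(data.target).map(dict(enumerate(data.target_names)))

# normalize=True returns fractions instead of raw counts
proportions = labels.value_counts(normalize=True)
print(proportions)
```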

Plotting a Pair Plot for the First Five Features

To explore the relationships between selected features visually, a pair plot is generated with Seaborn's pairplot() function. It shows pairwise scatter plots for the first five features, colored by the target class, with kernel density estimates on the diagonal (diag_kind='kde'). The two classes are also drawn with different markers (circles for class 0, malignant, and squares for class 1, benign), which helps reveal patterns or separability between the classes.

Python3
# Pairplot for selected features
selected_features = ['target'] + list(df.columns[:5])  # select first 5 features for visualization
sns.pairplot(df[selected_features], hue='target', markers=["o", "s"], diag_kind='kde')
plt.show()

Output:

[Figure: pair plot of mean radius, mean texture, mean perimeter, mean area, and mean smoothness, colored by target class]
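The separation suggested by the pair plot can be checked numerically by comparing the per-class means of the same five features; for example, mean radius is noticeably larger for malignant samples (class 0) than for benign ones (class 1). A short sketch:

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()
df = pd.DataFrame(data.data, columns=data.feature_names)
df['target'] = data.target

first_five = list(df.columns[:5])
# Mean of each of the first five features within each class (rows: 0, 1)
class_means = df.groupby('target')[first_five].mean()
print(class_means)
```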


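Step 3: Splitting into Training and Test Sets

The features and labels are split so that 80% of the samples are used for training and 20% are held out for testing (test_size=0.2), with random_state=42 for reproducibility.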

Python3
X = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target

# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
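Because the classes are imbalanced (357 benign vs. 212 malignant), passing stratify=y to train_test_split() keeps the class proportions the same in both splits. This is an optional variant of the split above, not the one the results in this article were produced with. A minimal sketch:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

data = load_breast_cancer()
X, y = data.data, data.target

# stratify=y preserves the benign/malignant ratio in both splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)

print(X_train.shape, X_test.shape)   # (455, 30) (114, 30)
# Fraction of benign (class 1) samples overall, in train, and in test
print(y.mean(), y_train.mean(), y_test.mean())
```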


Step 4: Building and Training the CatBoost Classifier

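A CatBoostClassifier is initialized with specific parameters: the number of iterations, learning rate, tree depth, loss function, evaluation metric, random seed, and verbosity.

The model is trained on the training data using fit(). With early_stopping_rounds=50, training stops once the evaluation metric on eval_set has not improved for 50 consecutive iterations, which helps prevent overfitting. Note that the test set is passed as eval_set here, so it influences when training stops and, because CatBoost's use_best_model defaults to True when an eval set is given, which iteration is kept. The test metrics in Step 5 are therefore slightly optimistic.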

Python3
# Initialize CatBoost classifier
model = CatBoostClassifier(iterations=1000,
                           learning_rate=0.1,
                           depth=6,
                           loss_function='Logloss',
                           eval_metric='Accuracy',
                           random_seed=42,
                           verbose=False)

# Train the model
model.fit(X_train, y_train, eval_set=(X_test, y_test), early_stopping_rounds=50, verbose_eval=False)


Step 5: Model Evaluation

Predictions are made on the test set using predict() and the model is evaluated using accuracy, precision, recall, and F1-score.

Python3
# Make predictions on test set
y_pred = model.predict(X_test)

# Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)

print("Model Evaluation Metrics:")
print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

Output:

Model Evaluation Metrics:
Accuracy: 0.9649122807017544
Precision: 0.958904109589041
Recall: 0.9859154929577465
F1 Score: 0.9722222222222222

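The evaluation metrics indicate that the model separates malignant and benign cases well on this test set. Keep in mind that precision_score(), recall_score(), and f1_score() use pos_label=1 by default, so the precision, recall, and F1 values above treat the benign class (label 1) as positive.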
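To report the metrics with malignant as the positive class (often the class of clinical interest), pass pos_label=0. The sketch below uses a small set of hypothetical labels, not the article's predictions, so the numbers are easy to check by hand; confusion_matrix() shows the counts behind both views.

```python
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Hypothetical labels for illustration only (0 = malignant, 1 = benign)
y_true = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
y_pred = [0, 0, 0, 1, 1, 1, 1, 1, 1, 0]

# Default pos_label=1: benign is treated as the positive class
recall_benign = recall_score(y_true, y_pred)
# pos_label=0: malignant is treated as the positive class
recall_malignant = recall_score(y_true, y_pred, pos_label=0)
precision_malignant = precision_score(y_true, y_pred, pos_label=0)

print("Recall (benign):", recall_benign)           # 5 of 6 benign found
print("Recall (malignant):", recall_malignant)     # 3 of 4 malignant found
print("Precision (malignant):", precision_malignant)
# Rows are true classes, columns are predicted classes, both ordered [0, 1]
print(confusion_matrix(y_true, y_pred))
```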

Step 6: Defining a Prediction Function

A function predict_breast_cancer() is defined to make predictions for new data points. It takes the trained model, input data, and feature names as inputs, and returns a prediction result in human-readable form.

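Input data with values for a few features is provided, and predict_breast_cancer() is used to predict whether the tumor is benign or malignant. Features not listed in input_data are passed to the model as missing (NaN) values, so for a meaningful prediction all 30 features should be supplied.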

Python3
# Function to predict breast cancer
def predict_breast_cancer(model, input_data, feature_names):
    # Convert input data to a one-row DataFrame; features missing from
    # input_data become NaN, which CatBoost treats as missing values
    input_df = pd.DataFrame([input_data], columns=feature_names)
    # Make prediction (an array holding one class label)
    prediction = model.predict(input_df)
    # Convert prediction to human-readable form.
    # In load_breast_cancer(), class 0 is malignant and class 1 is benign.
    if prediction[0] == 0:
        return "The tumor is predicted to be malignant."
    else:
        return "The tumor is predicted to be benign."

# Input certain values for prediction
input_data = {
    'mean radius': 15.0,
    'mean texture': 20.0,
    'mean perimeter': 90.0,
    'mean area': 600.0,
    'mean smoothness': 0.1,
    # Add the remaining 25 feature values here; omitted ones are passed as NaN
}

# Predict breast cancer
result = predict_breast_cancer(model, input_data, data.feature_names)
print("\nPrediction Result:")
print(result)

Output:

Prediction Result:
The tumor is predicted to be benign.
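Because predict_breast_cancer() passes every unsupplied feature as NaN, a prediction made from five of the thirty features says little. One option is to fill the missing features with typical values from the training data. The helper name complete_input and the choice of per-feature medians below are assumptions for illustration, not part of the original article.

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer(as_frame=True)
X = data.data  # DataFrame with the 30 feature columns

def complete_input(partial, reference):
    """Hypothetical helper: start from the per-feature medians of
    `reference` and overwrite the features given in `partial`."""
    row = reference.median()              # Series indexed by feature name
    for name, value in partial.items():
        if name not in row.index:
            raise KeyError(f"unknown feature: {name}")
        row[name] = value
    return row.to_frame().T               # one-row DataFrame, training column order

input_data = {
    'mean radius': 15.0,
    'mean texture': 20.0,
    'mean perimeter': 90.0,
    'mean area': 600.0,
    'mean smoothness': 0.1,
}
input_df = complete_input(input_data, X)
print(input_df.shape)                     # (1, 30)
```

In the article's pipeline, passing X_train from Step 3 as reference (instead of the full dataset) keeps the test set out of the filled values. The resulting input_df can be given to model.predict(), and model.predict_proba(input_df)[0] returns the probabilities of class 0 (malignant) and class 1 (benign).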

In this article, we trained a CatBoost classifier on the Breast Cancer dataset, evaluated it on a held-out test set, and used it to predict the class of a new sample.


