
Sklearn.StratifiedShuffleSplit() function in Python


In this article, we'll learn about the StratifiedShuffleSplit cross-validator from the sklearn library, which provides train/test indices for splitting data into training and testing sets.

What is StratifiedShuffleSplit?

StratifiedShuffleSplit is a combination of ShuffleSplit and StratifiedKFold. Because it stratifies, the class proportions of the full dataset are approximately preserved in both the train and test sets of every split. The main difference between StratifiedShuffleSplit and StratifiedKFold (shuffle=True) lies in the shuffling. StratifiedKFold shuffles the dataset only once at the beginning and then divides it into the specified number of folds, so the test sets of different folds never overlap.
StratifiedShuffleSplit, however, reshuffles the data independently before every split, as ShuffleSplit does. As a result, the test sets of different iterations can overlap, and some samples may never appear in any test set. Within a single split, the train and test indices never overlap.

Syntax: sklearn.model_selection.StratifiedShuffleSplit(n_splits=10, *, test_size=None, train_size=None, random_state=None) 


n_splits: int, default=10

Number of re-shuffling & splitting iterations.

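test_size: float or int, default=None

If float, should be between 0.0 and 1.0 and represents the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If None, the value is set to the complement of the train size; if train_size is also None, it is set to 0.1.

train_size: float or int, default=None

If float, should be between 0.0 and 1.0 and represents the proportion of the dataset to include in the train split. If int, represents the absolute number of train samples. If None, the value is set to the complement of the test size.

random_state: int, RandomState instance or None, default=None

Controls the randomness of the training and testing indices produced. Pass an int for reproducible output across multiple calls.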
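The following sketch shows how `test_size` and `train_size` translate into the number of train and test indices in a split. It uses toy labels (200 samples, 160 of class 0 and 40 of class 1) and a hypothetical helper, `split_sizes`.

```python
# How test_size / train_size translate into actual split sizes.
# Toy labels only: 200 samples, 160 of class 0 and 40 of class 1.
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

X = np.zeros((200, 1))              # feature values do not affect the split sizes
y = np.array([0] * 160 + [1] * 40)

def split_sizes(**kwargs):
    """Return (train size, test size) of the first split for the given arguments."""
    sss = StratifiedShuffleSplit(n_splits=3, random_state=42, **kwargs)
    train_idx, test_idx = next(sss.split(X, y))
    return len(train_idx), len(test_idx)

print(split_sizes(test_size=0.25))                 # float: 25% test -> (150, 50)
print(split_sizes(test_size=30))                   # int: 30 test samples -> (170, 30)
print(split_sizes())                               # both None: test_size 0.1 -> (180, 20)
print(split_sizes(test_size=0.2, train_size=0.5))  # remaining 30% unused -> (100, 40)
```

When train_size and test_size add up to less than the whole dataset, the leftover samples are excluded from that split entirely.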

Below is the implementation.

Step 1) Import required modules.


# import the libraries
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedShuffleSplit

Step 2) Load the dataset and identify the dependent and independent variables. 

The dataset can be downloaded from here.


# read the CSV file into a DataFrame
churn_df = pd.read_csv(r"ChurnData.csv")
# assign dependent and independent variables
X = churn_df[['tenure', 'age', 'address', 'income',
              'ed', 'employ', 'equip', 'callcard', 'wireless']]
y = churn_df['churn'].astype('int')
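If `ChurnData.csv` is not at hand, the remaining steps can still be tried on a stand-in. The sketch below does not use the real dataset. It is a hypothetical substitute that fills the same nine feature columns and the `churn` label with random values, then prints the label proportions, which StratifiedShuffleSplit tries to preserve in every split.

```python
# Hypothetical stand-in for ChurnData.csv: random values under the same
# column names as the article. The numbers themselves mean nothing.
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 200
features = ['tenure', 'age', 'address', 'income',
            'ed', 'employ', 'equip', 'callcard', 'wireless']

churn_df = pd.DataFrame(rng.normal(size=(n, len(features))), columns=features)
churn_df['churn'] = rng.choice([0, 1], size=n, p=[0.7, 0.3])  # imbalanced label

X = churn_df[features]
y = churn_df['churn'].astype('int')

# Share of each class; stratified splits aim to keep these proportions
print(y.value_counts(normalize=True))
```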

Step 3) Pre-process data.


# data pre-processing
X = preprocessing.StandardScaler().fit(X).transform(X)
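Here is a quick check of what Step 3 does, using a small made-up matrix. StandardScaler shifts each column to mean 0 and scales it to standard deviation 1. The result is a NumPy array rather than a DataFrame, which is why `X[train_index]` works in Step 5.

```python
# StandardScaler on a small made-up matrix: each column ends up with
# mean 0 and (population) standard deviation 1.
import numpy as np
from sklearn import preprocessing

X_raw = np.array([[1.0, 200.0],
                  [2.0, 300.0],
                  [3.0, 400.0],
                  [4.0, 500.0]])

X_scaled = preprocessing.StandardScaler().fit(X_raw).transform(X_raw)

# fit(X).transform(X) matches the shortcut fit_transform(X)
assert np.allclose(X_scaled, preprocessing.StandardScaler().fit_transform(X_raw))

print(type(X_scaled))         # a NumPy ndarray, not a DataFrame
print(X_scaled.mean(axis=0))  # approximately [0, 0]
print(X_scaled.std(axis=0))   # approximately [1, 1]
```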

Step 4) Create an object of the StratifiedShuffleSplit class.


# use StratifiedShuffleSplit()
sss = StratifiedShuffleSplit(n_splits=4, test_size=0.5,
                             random_state=0)
# get_n_splits() returns the number of splitting iterations (here 4)
print(sss.get_n_splits(X, y))
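The sketch below shows what the `sss` object produces by running the same configuration on toy data. It uses 100 random samples with a 70/30 label split in place of the churn data. Each iteration yields one pair of index arrays, and printing their sizes and the share of class 1 shows the stratification at work.

```python
# What StratifiedShuffleSplit yields: one (train indices, test indices) pair
# per iteration. Toy data stands in for the churn data here.
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

X = np.random.default_rng(0).normal(size=(100, 3))
y = np.array([0] * 70 + [1] * 30)   # 30% of samples belong to class 1

sss = StratifiedShuffleSplit(n_splits=4, test_size=0.5, random_state=0)
print(sss.get_n_splits(X, y))       # simply the n_splits value

for i, (train_index, test_index) in enumerate(sss.split(X, y)):
    # Within one iteration, train and test indices never overlap
    assert set(train_index).isdisjoint(test_index)
    print(f"split {i}: train={len(train_index)} test={len(test_index)} "
          f"class-1 share train={y[train_index].mean():.2f} "
          f"test={y[test_index].mean():.2f}")
```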


Step 5) Call split() on the instance to divide the data into training and testing samples. split() returns the train and test indices for each iteration. Then fit a classifier (here a random forest) on each training sample and record its accuracy on the corresponding test sample.


scores = []
# random forest classifier, refit on every training split
rf = RandomForestClassifier(n_estimators=40, max_depth=7)
for train_index, test_index in sss.split(X, y):
    # X is a NumPy array after scaling; y is a pandas Series, so use .iloc
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]
    rf.fit(X_train, y_train)
    pred = rf.predict(X_test)
    scores.append(accuracy_score(y_test, pred))
# get accuracy of each prediction
print(scores)
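One caveat about the steps above is that the scaler in Step 3 is fit on the whole dataset before splitting, so the test samples also influence the scaling statistics. A common alternative wraps the scaler and the classifier in a Pipeline and lets `cross_val_score` run the StratifiedShuffleSplit loop, which refits the scaler on each training split. The sketch below shows this variant. It is not the article's code, and `make_classification` stands in for the churn data.

```python
# Leakage-free variant: scaling happens inside each training split.
# make_classification provides synthetic data in place of ChurnData.csv.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic, imbalanced binary problem with 9 features (about 70/30 classes)
X, y = make_classification(n_samples=400, n_features=9, weights=[0.7, 0.3],
                           random_state=0)

model = make_pipeline(StandardScaler(),
                      RandomForestClassifier(n_estimators=40, max_depth=7,
                                             random_state=0))
sss = StratifiedShuffleSplit(n_splits=4, test_size=0.5, random_state=0)

# One score per split; for classifiers the default score is accuracy
scores = cross_val_score(model, X, y, cv=sss)
print(scores, scores.mean())
```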



Last Updated : 10 Oct, 2022