Using Early Stopping to Reduce Overfitting in Neural Networks

Overfitting is a common challenge in training neural networks. It occurs when a model memorizes the training data rather than learning patterns that generalize, leading to poor performance on unseen data. While regularization techniques such as dropout and weight decay help combat overfitting, early stopping stands out as a simple yet effective alternative. In this article, we will demonstrate how early stopping can reduce overfitting in neural networks.

What is Early Stopping?

Early stopping is a form of regularization that halts the training process when the performance of the model on a validation dataset starts to degrade. Instead of training the model until convergence, early stopping monitors the validation error during training and stops the training process when the validation error begins to increase.
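
The rule itself is simple enough to express in a few lines. Below is a minimal, framework-agnostic sketch of the early-stopping logic; the per-epoch validation losses are made-up numbers purely for illustration (in real training they would come from evaluating the model on the validation set each epoch):

# Hypothetical per-epoch validation losses, for illustration only
val_losses = [0.42, 0.31, 0.27, 0.26, 0.27, 0.28, 0.29]

patience = 3                      # epochs to wait for an improvement
best_val_loss = float('inf')
epochs_without_improvement = 0

for epoch, val_loss in enumerate(val_losses):
    if val_loss < best_val_loss:  # validation error improved
        best_val_loss = val_loss
        epochs_without_improvement = 0
    else:                         # no improvement this epoch
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Stopping early at epoch {epoch}")
            break

With these example values, training stops at epoch 6: the best loss (0.26) was reached at epoch 3, and three consecutive epochs pass without improving on it.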

Advantages of Early Stopping

Early stopping offers several practical benefits:

- Simplicity: it requires no changes to the model architecture or loss function, only monitoring of a validation metric.
- Reduced training cost: training stops as soon as the validation error stops improving, saving epochs of computation.
- Better generalization: the model is kept near the point where it performed best on held-out data, rather than the point where it best fit the training data.

Using Early Stopping to Reduce Overfitting in Neural Networks in Python

To demonstrate the effectiveness of early stopping in reducing overfitting, let's train two neural network models on the MNIST dataset: one with early stopping and one without. We will then compare their behavior during training and their accuracy on the test set.

Step 1: Data Loading and Preprocessing

import tensorflow as tf
from tensorflow.keras.datasets import mnist

# Load MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Preprocess data: Scale pixel values to [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0

Step 2: Model Definition and Compilation

Two neural network models are defined using the Sequential API provided by Keras. Sequential models are created by stacking layers one after another, making them easy to build and understand. Both models consist of the following layers:

- Flatten: converts each 28x28 input image into a 784-element vector.
- Dense (128 units, ReLU activation): a fully connected hidden layer.
- Dropout (rate 0.2): randomly drops 20% of the hidden activations during training, acting as a complementary regularizer.
- Dense (10 units, softmax activation): the output layer, producing a probability for each of the 10 digit classes.

After defining the model architecture, both models are compiled using the Adam optimizer, which is an adaptive learning rate optimization algorithm, and sparse categorical cross-entropy loss, suitable for multi-class classification tasks where the target labels are integers.

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout

# Define model architecture without early stopping
model_without_early_stopping = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dropout(0.2),
    Dense(10, activation='softmax')
])

# Compile model without early stopping
model_without_early_stopping.compile(optimizer='adam',
                                     loss='sparse_categorical_crossentropy',
                                     metrics=['accuracy'])

Step 3: Model Training

# Train model without early stopping
history_without_early_stopping = model_without_early_stopping.fit(
    x_train, y_train, epochs=20, validation_split=0.2)

Step 4: Applying Early Stopping to Prevent Overfitting

from tensorflow.keras.callbacks import EarlyStopping

# Define model architecture with early stopping
model_with_early_stopping = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dropout(0.2),
    Dense(10, activation='softmax')
])

# Compile model with early stopping
model_with_early_stopping.compile(optimizer='adam',
                                  loss='sparse_categorical_crossentropy',
                                  metrics=['accuracy'])

# Define early stopping callback
early_stopping = EarlyStopping(monitor='val_loss', patience=3)

# Train model with early stopping
history_with_early_stopping = model_with_early_stopping.fit(
    x_train, y_train, epochs=20, validation_split=0.2,
    callbacks=[early_stopping])
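
The callback above stops training once the validation loss has failed to improve for 3 consecutive epochs. EarlyStopping also accepts a few other commonly used arguments; the variant below is a sketch using standard Keras options (it is not the configuration used in this experiment, and the name early_stopping_tuned is ours) that additionally ignores negligible improvements and rolls the model back to its best weights:

from tensorflow.keras.callbacks import EarlyStopping

# Optional variant (not used in the experiment above); all arguments
# shown are standard Keras options
early_stopping_tuned = EarlyStopping(
    monitor='val_loss',           # quantity to watch
    patience=3,                   # epochs with no improvement before stopping
    min_delta=0.001,              # smallest change that counts as an improvement
    restore_best_weights=True     # roll back to the weights of the best epoch
)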

Step 5: Model Evaluation

# Evaluate models on test data
test_loss_without_early_stopping, test_acc_without_early_stopping = model_without_early_stopping.evaluate(x_test, y_test)
test_loss_with_early_stopping, test_acc_with_early_stopping = model_with_early_stopping.evaluate(x_test, y_test)

# Print test accuracies
print("Test Accuracy without Early Stopping:", test_acc_without_early_stopping)
print("Test Accuracy with Early Stopping:", test_acc_with_early_stopping)

Complete Code

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout
from tensorflow.keras.callbacks import EarlyStopping

# Load MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Preprocess data: Scale pixel values to [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0

# Define model architecture without early stopping
model_without_early_stopping = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dropout(0.2),
    Dense(10, activation='softmax')
])

# Compile model without early stopping
model_without_early_stopping.compile(optimizer='adam',
                                     loss='sparse_categorical_crossentropy',
                                     metrics=['accuracy'])

# Train model without early stopping
history_without_early_stopping = model_without_early_stopping.fit(
    x_train, y_train, epochs=20, validation_split=0.2)

# Define model architecture with early stopping
model_with_early_stopping = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dropout(0.2),
    Dense(10, activation='softmax')
])

# Compile model with early stopping
model_with_early_stopping.compile(optimizer='adam',
                                  loss='sparse_categorical_crossentropy',
                                  metrics=['accuracy'])

# Define early stopping callback
early_stopping = EarlyStopping(monitor='val_loss', patience=3)

# Train model with early stopping
history_with_early_stopping = model_with_early_stopping.fit(
    x_train, y_train, epochs=20, validation_split=0.2,
    callbacks=[early_stopping])

# Evaluate models on test data
test_loss_without_early_stopping, test_acc_without_early_stopping = model_without_early_stopping.evaluate(x_test, y_test)
test_loss_with_early_stopping, test_acc_with_early_stopping = model_with_early_stopping.evaluate(x_test, y_test)

# Print test accuracies
print("Test Accuracy without Early Stopping:", test_acc_without_early_stopping)
print("Test Accuracy with Early Stopping:", test_acc_with_early_stopping)

Output:

Test Accuracy without Early Stopping: 0.9782999753952026
Test Accuracy with Early Stopping: 0.9790999889373779

To validate the efficacy of early stopping, we trained two neural network models on the MNIST dataset: one with early stopping and one without. Both models achieve high accuracy, and in this run the model trained with early stopping reaches a slightly higher test accuracy (about 97.91%) than the model trained without it (about 97.83%), despite training for fewer epochs.

The exact numbers will vary between runs, and on a dataset as forgiving as MNIST the gap is small. The important point is the behavior: the model without early stopping runs for all 20 epochs and is more prone to overfitting the training data, whereas the early-stopped model halts as soon as the validation loss stops improving and therefore tends to generalize better to unseen data.

The small difference in test accuracy underscores the trade-off between maximizing training performance and optimizing for generalization. In most real-world scenarios, preventing overfitting and achieving good generalization are paramount, making early stopping a valuable regularization technique even when the gains in test accuracy are modest.
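
To confirm that early stopping actually shortened training, you can compare how many epochs each model ran, which the History object records (this continues from the code above):

# Number of epochs each model actually trained for
print("Epochs without early stopping:", len(history_without_early_stopping.history['loss']))
print("Epochs with early stopping:", len(history_with_early_stopping.history['loss']))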
