Save and Load Models in PyTorch

Last Updated : 23 Feb, 2024

It often happens that we need to use an already-trained model in our development environment. In that case, would you retrain the model from scratch every time, or would you save it once and load it whenever it is needed? You would definitely choose the second option. In this article, we will see how to save and load models using PyTorch.

What is PyTorch?

PyTorch is an open-source machine learning library built around a dynamic computation graph. In the static approach, the graph is fully defined before execution; in the dynamic approach that PyTorch follows, the structure of the graph can change during execution based on the input data. This makes it straightforward to create and train neural networks that extract hidden patterns from data, as the sketch below illustrates.
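
For instance, here is a minimal sketch of what a dynamic graph allows (the DynamicNet class and its names are illustrative, not part of this article's model): the forward pass can branch on the input with ordinary Python control flow, and the graph is rebuilt on every call.

Python3

import torch
import torch.nn as nn

class DynamicNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        # The graph is built on the fly, so plain Python control
        # flow can change its structure from one input to the next
        if x.sum() > 0:
            return torch.relu(self.fc(x))
        return torch.tanh(self.fc(x))

model = DynamicNet()
print(model(torch.ones(1, 4)))   # takes the relu branch
print(model(-torch.ones(1, 4)))  # takes the tanh branch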

You might wonder what a neural network is. In simple words, a neural network is a collection of interconnected layers of nodes: each node processes the data it receives and passes the result on to the next. Together, the layers learn to extract insights from the data.

Stepwise Guide to Save and Load Models in PyTorch

Now, we will see how to create a model using PyTorch.

Creating Model in PyTorch

To save and load a model, we will first create a deep-learning model for image classification. This model will classify images of handwritten digits from the MNIST dataset. The code below implements a convolutional neural network (CNN) for this task. The data is loaded and transformed into PyTorch tensors, which are like containers that store the data.

The following code shows the creation of the PyTorch Model.

Importing Necessary Libraries

Python3
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


Data Transformation

The given code defines a transformation pipeline using torchvision.transforms.Compose for preprocessing image data before feeding it into a PyTorch model.

  • transforms.ToTensor(): Converts the input image (assumed to be in PIL Image format) to a PyTorch tensor. It converts the image data type to torch.FloatTensor and scales the pixel values to the range [0.0, 1.0].
  • transforms.Normalize((0.5,), (0.5,)): Normalizes the tensor image with mean and standard deviation. The provided mean and standard deviation values (0.5,) and (0.5,) respectively are used to normalize each channel of the input tensor. This transformation normalizes the tensor values to be in the range [-1.0, 1.0].

Python3
# Define transformation to apply to the data
data_transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize((0.5,), (0.5,))  # Normalize the pixel values to range [-1, 1]
])
 
# Download MNIST dataset and apply the transformation
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=data_transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=data_transform, download=True)
 
 
# Define data loaders to load the data in batches during training and testing
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


Defining the Neural Network Architecture

  1. Class Definition: The code defines a class SimpleCNN that inherits from nn.Module, which is the base class for all neural network modules in PyTorch. This class represents a simple convolutional neural network (CNN).
  2. Initialization: In the __init__ method, the code defines the layers of the CNN. It includes two convolutional layers (conv1_layer and conv2_layer) with specified kernel sizes and padding, and two fully connected layers (fc1_layer and fc2_layer) with specified input and output sizes.
  3. Forward Pass: The forward method defines the forward pass of the network. It applies a ReLU activation function after each convolutional layer and uses max pooling with a kernel size of 2 and stride of 2 to downsample the feature maps. The output of the second convolutional layer is flattened before being passed to the fully connected layers.
  4. View Operation: The view operation reshapes the output of the second convolutional layer to be compatible with the input size of the first fully connected layer. The -1 argument in view indicates that the size of that dimension should be inferred based on the other dimensions.
  5. Model Instance: Finally, an instance of the SimpleCNN class is created and assigned to the variable cnn_model. This instance represents the actual neural network that can be trained and used for inference.

Python3
# Here we are adding convolution layer and fully connected layers in neural network
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1_layer = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2_layer = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1_layer = nn.Linear(32 * 7 * 7, 128)
        self.fc2_layer = nn.Linear(128, 10)
 
    # Forward pass with ReLU activations and max-pooling layers
    def forward(self, inputs):
        new_input = torch.relu(self.conv1_layer(inputs))
        new_input = torch.max_pool2d(new_input, kernel_size=2, stride=2)
        new_input = torch.relu(self.conv2_layer(new_input))
        new_input = torch.max_pool2d(new_input, kernel_size=2, stride=2)
        new_input = new_input.view(-1, 32 * 7 * 7)
        new_input = torch.relu(self.fc1_layer(new_input))
        new_input = self.fc2_layer(new_input)
        return new_input
 
 
# Creating Model Instance
cnn_model = SimpleCNN()


Loss Function and Optimizer

  1. Loss Function: nn.CrossEntropyLoss() is used as the loss function. This loss function is commonly used for classification problems with multiple classes. It calculates the cross-entropy loss between the predicted class probabilities and the actual class labels.
  2. Optimizer: optim.Adam is used as the optimizer. Adam is a popular optimization algorithm that computes adaptive learning rates for each parameter. It is well-suited for training deep neural networks. The optimizer is initialized with the parameters of the cnn_model and a learning rate of 0.001.

Python3
# Define loss function and optimizer
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn_model.parameters(), lr=0.001)


Training the Model

The code implements the following steps:

  1. Outer Loop (Epochs): The code iterates over 5 epochs using a for loop. An epoch is a single pass through the entire dataset.
  2. Inner Loop (Batches): Within each epoch, the code iterates over train_loader, which yields batches of input data (inputs) and their corresponding labels (labels).
  3. Zero Gradients: Before the backward pass (loss.backward()), optimizer.zero_grad() is called to zero out the gradients of the model parameters. This is necessary because PyTorch accumulates gradients by default.
  4. Forward and Backward Pass:
    • outputs = cnn_model(inputs) performs the forward pass, where the model processes the input data to generate predictions (outputs).
    • loss = loss_func(outputs, labels) calculates the loss between the predicted outputs and the actual labels.
    • loss.backward() computes the gradients of the loss with respect to the model parameters, enabling backpropagation.
    • optimizer.step() updates the model parameters based on the computed gradients, using the optimization algorithm (Adam in this case) to adjust the weights.
  5. Loss Calculation: Within the inner loop, running_loss accumulates the total loss across batches. At the end of each epoch, the average loss per batch is printed to monitor the training progress.

Python3
# Train model
for epoch in range(5):  # Train for 5 epochs
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()  # Zero the gradients
        outputs = cnn_model(inputs)  # Forward pass
        loss = loss_func(outputs, labels)  # Calculate the loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights
 
 
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")


Output:

Epoch 1, Loss: 0.22154594235159933
Epoch 2, Loss: 0.05766747533348697
Epoch 3, Loss: 0.04144403319505514
Epoch 4, Loss: 0.029859573355312946
Epoch 5, Loss: 0.024109310584392515

Testing the Model

Python3
# Test model
correct_predictions = 0
total_samples = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = cnn_model(inputs)
        _, predicted_labels = torch.max(outputs.data, 1)
        total_samples += labels.size(0)
        correct_predictions += (predicted_labels == labels).sum().item()
 
print(f"Accuracy of test set: {100 * correct_predictions / total_samples}%")


Output:

Accuracy of test set: 99.16%

Saving and Loading Model

Method 1: Using torch.save() and torch.load()

The following code shows how to save and load the model using the built-in functions provided by the torch module. Here, torch.save() serializes the entire model object to a file, and torch.load() loads it back into memory. Because this approach pickles the whole object, the model class (SimpleCNN here) must still be defined or importable when the file is loaded, and on newer PyTorch releases you may need to pass weights_only=False to torch.load() for a full pickled model.

Python3
# Save the entire model object
torch.save(cnn_model, 'cnn_model.pth')
 
# Load the model back (SimpleCNN must be defined or importable)
loaded_model = torch.load('cnn_model.pth')
 
# Set the model to evaluation mode
loaded_model.eval()


Output:

SimpleCNN(
(conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
(fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
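
Once loaded and switched to evaluation mode, the model can be used for inference right away. A quick sanity check, reusing the test_loader defined earlier:

Python3

# Run the loaded model on a single test batch
with torch.no_grad():
    images, labels = next(iter(test_loader))
    outputs = loaded_model(images)
    _, predicted = torch.max(outputs, 1)
    print("Predicted:", predicted[:10].tolist())
    print("Actual:   ", labels[:10].tolist())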

Method 2: Using model.state_dict()

Now, let us see another way to save and load the model, using the state_dict() method. A state_dict stores only the learnable parameters of the model. When loading, a new model with the same architecture is created, and its parameters are replaced with the stored ones. Since only the parameters are stored, rather than the whole pickled object, this method is lighter on storage and is the approach recommended by the PyTorch documentation. The following code snippet illustrates it.

Python3
# Saving the model
torch.save(cnn_model.state_dict(), 'cnn_model.pth')
 
# Loading the model
loaded_model = SimpleCNN()
loaded_model.load_state_dict(torch.load('cnn_model.pth'))
print(loaded_model)


Output:

SimpleCNN(
(conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
(fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
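
One practical note: if the parameters were saved on a GPU machine and are loaded on a CPU-only machine (or vice versa), the map_location argument of torch.load() remaps the stored tensors onto the target device. A short sketch:

Python3

# Load parameters onto the CPU regardless of where they were saved
state_dict = torch.load('cnn_model.pth', map_location=torch.device('cpu'))
loaded_model = SimpleCNN()
loaded_model.load_state_dict(state_dict)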

Method 3: Saving and Loading Using Checkpoints

The checkpoint method saves the model by creating a dictionary that contains all the necessary information: the model state_dict, the optimizer state_dict, the current epoch, the last loss, and so on. To load the model, the checkpoint file is read and each piece of information is restored. The method is demonstrated below:

Python3
# Saving the checkpoint
checkpoint = {
    'epoch': epoch,
    'model_state_dict': cnn_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # you may add any other information here
}
torch.save(checkpoint, 'checkpoint.pth')
 
# Loading the checkpoint
checkpoint = torch.load('checkpoint.pth')
cnn_model = SimpleCNN()
cnn_model.load_state_dict(checkpoint['model_state_dict'])
 
# Re-create the optimizer for the new model instance before restoring
# its state, so that it tracks the right parameters
optimizer = optim.Adam(cnn_model.parameters(), lr=0.001)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
 
epoch = checkpoint['epoch']
loss = checkpoint['loss']
print(cnn_model)


Output:

SimpleCNN(
(conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
(fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
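
The point of storing the epoch and loss is that training can resume where it stopped. Here is a minimal sketch of resuming for a couple of extra epochs, reusing the training objects defined above (the number of extra epochs is arbitrary):

Python3

# Resume training from the epoch stored in the checkpoint
start_epoch = checkpoint['epoch'] + 1
cnn_model.train()
for epoch in range(start_epoch, start_epoch + 2):
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = cnn_model(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")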

Conclusion

There are several ways to save and load models created with PyTorch. torch.save() and torch.load() can serialize and restore the entire model object, while saving model.state_dict() provides a lighter, storage-efficient approach that keeps only the parameters. If you want to keep all the relevant training information together, a checkpoint dictionary can store the model state, optimizer state, epoch, and loss in one file. These methods let us manage models and transfer parameters and other information between environments. Choosing the right one comes down to storage constraints, whether you need information beyond the model parameters, and your use case.

Frequently Asked Questions

Q. What is ‘.pth’ in a PyTorch model file?

‘.pth’ is the conventional file extension for serialized PyTorch objects. A .pth file stores the model’s parameters, such as weights and biases, together with their tensor values, so the model can later be reconstructed with the same parameters.

Q. How do I create a copy of a model in PyTorch?

To create a copy of a model, you can use copy.deepcopy() and assign the copy to another variable. Alternatively, create a new instance of the model class and copy the parameters (weights and biases) into it using load_state_dict(). Both approaches are sketched below.
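
A minimal sketch of both approaches, using the cnn_model defined earlier:

Python3

import copy

# Approach 1: deep-copy the whole model object
model_copy = copy.deepcopy(cnn_model)

# Approach 2: new instance, then copy the parameters across
model_copy2 = SimpleCNN()
model_copy2.load_state_dict(cnn_model.state_dict())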

Q. How do I save a model file in Python?

In Python, you can save a PyTorch model file using the torch.save() function. It can serialize either the entire model (architecture plus trained weights) or just its state_dict to a file, as shown below.
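
For example, both variants are valid (using the cnn_model defined in this article; the file names are arbitrary):

Python3

torch.save(cnn_model, 'full_model.pth')            # entire model object
torch.save(cnn_model.state_dict(), 'weights.pth')  # parameters only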


