
PyTorch vs PyTorch Lightning

Last Updated : 21 Mar, 2024

Training models in plain PyTorch means re-writing the same boilerplate in every project: the training loop, device placement, logging, and checkpointing. PyTorch Lightning was introduced to address these challenges and provide a more organized and standardized approach. In this article, we will see the major differences between PyTorch Lightning and PyTorch.

PyTorch

PyTorch is an open-source deep learning framework widely used for research and production applications. It builds a dynamic computational graph (define-by-run), which allows for more flexibility and ease of use compared to static computational graph frameworks.
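To make the dynamic-graph point concrete, here is a small illustrative sketch (not part of this article's MNIST example) of define-by-run autograd: the graph is recorded as ordinary Python control flow executes, so a data-dependent loop can change the graph on every call.

Python3
import torch

# The graph is built on the fly: ordinary Python control flow
# (the while loop below) decides how many operations get recorded.
x = torch.tensor(2.0, requires_grad=True)
y = x
while y.norm() < 10:       # data-dependent control flow
    y = y * 2
loss = y.sum()
loss.backward()            # gradients flow through the graph that was actually built
print(x.grad)              # here y == 8 * x, so this prints tensor(8.)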

PyTorch Lightning: A High-Level Framework Built on PyTorch

PyTorch Lightning is a lightweight PyTorch wrapper that provides a high-level interface for training PyTorch models. It is designed to simplify and standardize the training loop, making it easier to write cleaner, more modular code for deep learning projects. PyTorch Lightning introduces a set of abstractions and conventions that remove boilerplate code and allow researchers and practitioners to focus more on the model architecture and experiment configurations.

PyTorch vs PyTorch Lightning

PyTorch and PyTorch Lightning are both frameworks for building and training neural network models, but they differ in terms of abstraction, structure, and ease of use. Here are some key differences between PyTorch and PyTorch Lightning:

| Feature | PyTorch | PyTorch Lightning |
|---|---|---|
| Training Loop | Users write the training, validation, and testing loops explicitly, including moving data to the GPU, computing gradients, and updating model parameters. | The Trainer runs the loop; users define hooks and callbacks to customize behavior without modifying it directly. |
| Model Setup | Users define the model, loss function, optimizer, and other components explicitly. | Standardized through dedicated LightningModule methods. |
| Abstraction Level | Lower-level, requires more manual coding. | Higher-level, hides boilerplate code. |
| GPU and Distributed Training | Requires manually moving models and data to the GPU, managing distributed training, and handling multi-GPU scenarios. | Automatic based on user configuration; users specify the accelerator and number of devices. |
| Logging and Experiment Tracking | Logging metrics and tracking experiments require manual implementation using tools like TensorBoard or custom loggers. | Built-in support for various logging frameworks (TensorBoard, CSV, etc.) and experiment tracking platforms (e.g., WandB, Comet). |
| Debugging and Profiling | Manual instrumentation. | Provides hooks for common debugging and profiling tools. |
| Integration with Other Libraries | User-managed integrations. | Built-in integrations. |
| Best Practices | Discovered and applied independently. | Benefits from a growing community and ecosystem, allowing users to leverage pre-built components and best practices. |
| Standardized Interface | Code structure may vary between projects. | Enforces a consistent structure through LightningModule. |
| Module System | Left to developer discretion. | Promotes a modular design with LightningModules. |
| Checkpointing | Users must implement checkpointing logic themselves. | Built-in support for simplified model checkpointing (see the sketch after this table). |
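For a few of the rows above (GPU handling, logging, checkpointing, and debugging), here is a brief sketch of how they are typically configured through the Lightning Trainer. ModelCheckpoint, TensorBoardLogger, and the Trainer flags shown are real Lightning components, while MyLightningModel and the data loaders are hypothetical stand-ins for whatever LightningModule and DataLoaders you define.

Python3
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Checkpointing: keep the best model by validation loss (built-in callback).
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

# Logging: metrics reported with self.log(...) in the LightningModule are
# routed to this logger automatically.
logger = TensorBoardLogger("lightning_logs", name="mnist")

# GPU / distributed training: chosen by configuration, with no manual .to(device) calls.
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="auto",        # uses a GPU if one is available, otherwise the CPU
    devices="auto",
    callbacks=[checkpoint_cb],
    logger=logger,
    # For debugging/profiling, flags such as fast_dev_run=True or
    # profiler="simple" can also be passed here.
)

# trainer.fit(MyLightningModel(), train_loader, val_loader)  # hypothetical objects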

Implementation: From PyTorch to PyTorch Lightning

Let’s illustrate the difference in code between a basic PyTorch script and its equivalent using PyTorch Lightning. Consider a simple training script for a neural network in both PyTorch and PyTorch Lightning.

Let’s compare the training and validation loops for a simple 3-layer neural network on the MNIST dataset using both PyTorch and PyTorch Lightning. The key ingredients include the model, dataset (MNIST), optimizer, and loss function.

PyTorch

1. Importing necessary libraries and modules

The code starts by importing the necessary libraries and modules for building and training the neural network. These include torch, torch.nn, torch.optim, and torchvision.

Python3
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

2. Defining the neural network model

Next, the code defines a simple neural network model using PyTorch’s nn.Module class. The model consists of three fully connected layers (fc1, fc2, and fc3) with 256, 128, and 10 neurons, respectively. The output layer has 10 neurons, corresponding to the 10 classes of digits in the MNIST dataset.

The forward method defines the forward pass of the neural network: the input image is flattened to a 784-dimensional vector and passed through the layers, with the ReLU activation applied after the first two fully connected layers (the final layer outputs raw logits).

Python3
# Define the model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

3. Loading the dataset

The code then loads the MNIST dataset using torchvision.datasets.MNIST. The dataset is preprocessed using the transforms.Compose method, which applies a series of transformations to the data. In this case, the data is converted to a tensor and normalized using the MNIST mean (0.1307) and standard deviation (0.3081).

The train_dataset and test_dataset objects are created and loaded into train_loader and test_loader, which are PyTorch DataLoader objects that handle batching and shuffling of the data.

Python3
# Load the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

4. Initializing the model, optimizer, and loss function

The neural network model is initialized, and the stochastic gradient descent (SGD) optimizer and cross-entropy loss function are defined.

Python3
# Initialize the model
model = NeuralNetwork()

# Define the optimizer and loss function
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

5. Training and validating the model

The code defines two functions, train and validate, which handle the training and validation of the neural network.

  1. The train function takes in the model, training data loader, optimizer, and loss function, and trains the model on the data in batches. For each batch the gradients are computed via backpropagation and the model weights are updated using the SGD optimizer.
  2. The validate function takes in the model, test data loader, and loss function, and evaluates the model on the test data. The test loss is computed and the accuracy of the model is calculated as the percentage of correct predictions.

Finally, the code trains and validates the neural network for 10 epochs using the train and validate functions.

Python3
# Training loop
def train(model, train_loader, optimizer, criterion):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Validation loop
def validate(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)  # sum of per-batch mean losses divided by dataset size
    accuracy = 100. * correct / len(test_loader.dataset)
    print('Validation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        test_loss, correct, len(test_loader.dataset), accuracy))
    
# Train and validate the model
for epoch in range(10):
    train(model, train_loader, optimizer, criterion)
    validate(model, test_loader, criterion)

Output:

Validation set: Average loss: 0.0052, Accuracy: 9050/10000 (90.50%)
Validation set: Average loss: 0.0041, Accuracy: 9267/10000 (92.67%)
Validation set: Average loss: 0.0034, Accuracy: 9369/10000 (93.69%)
Validation set: Average loss: 0.0030, Accuracy: 9445/10000 (94.45%)
Validation set: Average loss: 0.0026, Accuracy: 9506/10000 (95.06%)
Validation set: Average loss: 0.0024, Accuracy: 9555/10000 (95.55%)
Validation set: Average loss: 0.0021, Accuracy: 9597/10000 (95.97%)
Validation set: Average loss: 0.0020, Accuracy: 9641/10000 (96.41%)
Validation set: Average loss: 0.0018, Accuracy: 9663/10000 (96.63%)
Validation set: Average loss: 0.0017, Accuracy: 9680/10000 (96.80%)
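
Note that the script above runs entirely on the CPU; in plain PyTorch, moving the model and every batch to a GPU is the user's responsibility, as the comparison table mentions. A minimal sketch of the extra device-handling code (an assumed addition, not part of the script above) looks like this:

Python3
import torch

# Assumed addition to the PyTorch script above: manual device placement.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

for data, target in train_loader:
    # Every batch must be moved to the same device as the model.
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()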

PyTorch Lightning

1. Importing necessary libraries and modules

First, the necessary imports are made, including PyTorch, PyTorch Lightning, the MNIST dataset utilities, and the Adam optimizer.

Python3
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.optim import Adam
import pytorch_lightning as pl


2. Defining the neural network model

Next, the MyModel class is defined, which inherits from pl.LightningModule. This class defines the neural network architecture, the forward pass, the training step, the validation step, and the configuration of the optimizer.

  • In the __init__ method, the neural network architecture is defined using PyTorch’s nn.Sequential module. It consists of three fully connected layers with ReLU activations and a final Softmax layer that outputs class probabilities. (Note that nn.CrossEntropyLoss, used in training_step below, already applies log-softmax internally, so the explicit Softmax layer is redundant; the model still trains, but returning raw logits is the more common convention.)
  • The forward method takes in an input tensor x, reshapes it to have the correct number of dimensions, and passes it through the neural network using the self.model module.
Python3
class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Reshape the input
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.CrossEntropyLoss()(y_hat, y)
        self.log('val_loss', loss)
        # Calculate accuracy
        correct = (y_hat.argmax(1) == y).sum().item()
        total = y.size(0)
        self.log('accuracy', correct / total, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=0.001)

3. Loading the dataset

The MNIST dataset is loaded using the MNIST class from torchvision.datasets. The training and test splits are wrapped in separate DataLoader objects for training and validation.

Python3
# Load the MNIST dataset
train_dataset = MNIST(root='.', train=True, transform=ToTensor(), download=True)
val_dataset = MNIST(root='.', train=False, transform=ToTensor())

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)

4. Initializing the model

Python3
# Initialize the model
model = MyModel()

5. Initializing the trainer

The trainer object is created using pl.Trainer. The max_epochs argument is set to 10, which means that the model will be trained for 10 epochs. The accelerator argument is set to "gpu" when a CUDA device is available and falls back to "cpu" otherwise.

Python3
# Initialize the trainer
trainer = pl.Trainer(max_epochs=10, accelerator="gpu" if torch.cuda.is_available() else "cpu")


6. Model fitting

The model is then trained using the trainer.fit method, which takes the model object together with the train_loader and val_loader as arguments.

Python3
# Train the model
trainer.fit(model, train_loader, val_loader)


7. Validating the model

The trainer.validate method runs the validation loop once on the supplied DataLoader and returns the metrics logged in validation_step.

Python3
trainer.validate(model, val_loader)

Output:

[{'val_loss': 1.4881926774978638, 'accuracy': 0.9729999899864197}]

Code Difference Takeaways

| Feature | PyTorch | PyTorch Lightning |
|---|---|---|
| Inheritance | Inherits from nn.Module | Inherits from pl.LightningModule |
| Architecture Definition | Separate class or custom definition | nn.Sequential inside the MyModel class |
| Code Structure | Separate functions for training and validation | Organized within MyModel with dedicated methods |
| Training Loop | Explicitly written for loop | Abstracted, handled by trainer.fit |
| Optimizer and Scheduler | Defined and configured within the training loop | Defined in the configure_optimizers method (a scheduler variant is sketched below) |
| Logging Metrics | Manual implementation with external libraries | Simplified using self.log within the LightningModule |
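
As a concrete illustration of the "Optimizer and Scheduler" row, configure_optimizers can also return a learning-rate scheduler alongside the optimizer. The sketch below is a hypothetical variant of MyModel's method (StepLR and its settings are assumptions, not part of the original code); the Trainer steps the scheduler automatically once per epoch.

Python3
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# Hypothetical replacement for MyModel.configure_optimizers that also
# returns a scheduler; Lightning steps it once per epoch by default.
def configure_optimizers(self):
    optimizer = Adam(self.parameters(), lr=0.001)
    scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
    return [optimizer], [scheduler]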

Conclusion

PyTorch Lightning serves as a powerful tool for researchers and practitioners in the deep learning community, offering a standardized and organized framework for building and training models. By abstracting away common boilerplate code, automating training processes, and providing a modular structure, PyTorch Lightning simplifies the development workflow and enhances collaboration.


