
Generative Adversarial Networks (GANs) in PyTorch

Last Updated : 28 Mar, 2024

This article implements the GAN architecture using the PyTorch framework and explains the code in detail, step by step.

Generative Adversarial Networks (GANs) are a class of artificial intelligence algorithms used in unsupervised machine learning. They consist of two neural networks, the generator and the discriminator, which are trained simultaneously through a competitive process. The generator creates new data instances, while the discriminator evaluates whether they are real (from the true data distribution) or fake (produced by the generator). Through this adversarial training, both networks improve over time: the generator learns to produce more convincing samples, and the discriminator becomes better at telling real data from generated data.
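In the standard formulation, this competition is a two-player minimax game over a value function V(D, G), where D(x) is the discriminator's estimated probability that x is real, G(z) is the generator's output for a noise vector z, p_data is the real data distribution and p_z is the noise distribution:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator tries to push D(x) toward 1 on real data and D(G(z)) toward 0 on generated data, while the generator tries the opposite. In practice, the generator is often trained with the non-saturating variant, minimizing -log D(G(z)) instead of log(1 - D(G(z))); in the code in this article, that is what labeling the fake images as real in the generator step amounts to.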

Implementing GANs using PyTorch Framework

In this section, we implement a GAN that generates handwritten digits in the style of the MNIST dataset, using the following steps:

Step 1: Importing Necessary Libraries

We import the core PyTorch libraries: torch, torch.nn for building the networks, and torch.optim for updating their parameters. torchvision is used to load the MNIST dataset, and torchvision.transforms defines the preprocessing applied to the MNIST images before they are fed into the GAN. matplotlib.pyplot is used to plot the generated images.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

Step 2: Define the Generator Class

We define the Generator class as follows:

  • Initialization: Inherits from nn.Module and takes a parameter noise_dim, the dimensionality of the input noise vector. The network architecture is defined in this method.
  • Architecture: Uses a sequential network (self.main). A linear layer projects the noise vector to 7 × 7 × 256 values, which nn.Unflatten reshapes into a 256-channel 7×7 feature map. Three transposed convolution layers, the first two followed by batch normalization and ReLU, then reduce the channel count and upsample the feature map from 7×7 to 14×14 and finally to a 28×28 grayscale image resembling a handwritten digit.
  • Output Layer: The final layer applies a Tanh activation, which squashes the pixel values of the output image into the range [-1, 1]. This matches the range of the real MNIST images after normalization in Step 7.
  • Forward Method: Implements the forward pass of the generator. It takes a batch of noise vectors (x) of shape (batch_size, noise_dim) and passes it through the sequential model (self.main) to produce a batch of images of shape (batch_size, 1, 28, 28).
# Generator
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.main = nn.Sequential(
            nn.Linear(noise_dim, 7 * 7 * 256),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, 5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)
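To check that the generator really ends at 28×28, we can apply the output-size formula for nn.ConvTranspose2d (with dilation 1), out = (in - 1) × stride - 2 × padding + kernel_size + output_padding, to each layer in turn. The sketch below is plain Python; convtranspose_out is a hypothetical helper name, not a PyTorch function:

```python
def convtranspose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Spatial output size of nn.ConvTranspose2d with dilation=1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# (kernel, stride, padding, output_padding) of the three ConvTranspose2d layers in Generator
layers = [
    (5, 1, 2, 0),  # 256 -> 128 channels
    (5, 2, 2, 1),  # 128 -> 64 channels
    (5, 2, 2, 1),  # 64 -> 1 channel
]

size = 7  # nn.Unflatten produces a 7x7 feature map
sizes = [size]
for kernel, stride, padding, output_padding in layers:
    size = convtranspose_out(size, kernel, stride, padding, output_padding)
    sizes.append(size)

print(sizes)  # [7, 7, 14, 28]
```

The first layer keeps the 7×7 size and only reduces the channel count. Without output_padding=1, a stride-2 layer would map 7 to 13 rather than 14, so that extra padding is what makes each stride-2 layer exactly double the size.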

Step 3: Define the Discriminator Class

We define the Discriminator class as follows:

  • Initialization: Inherits from nn.Module. Its constructor takes no arguments, because the architecture is fixed for 1 × 28 × 28 input images.
  • Architecture: Uses a sequential network (self.main) of two strided convolutional layers, each followed by a LeakyReLU activation (negative slope 0.2) and batch normalization. These layers downsample the 28×28 input to 14×14 and then to 7×7, and the resulting 128-channel feature map is flattened.
  • Output Layer: The final fully connected layer maps the 7 × 7 × 128 flattened features to a single raw score (logit) per image. No sigmoid is applied here, because nn.BCEWithLogitsLoss in Step 6 applies it internally; a higher logit means the discriminator considers the image more likely to be real.
  • Forward Method: Implements the forward pass of the discriminator. It takes a batch of images (x) and passes it through the sequential model (self.main) to compute one logit per image.
# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.Flatten(),
            nn.Linear(7 * 7 * 128, 1)
        )

    def forward(self, x):
        return self.main(x)
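Likewise, the input size of the final nn.Linear layer (7 * 7 * 128) follows from the nn.Conv2d output-size formula with dilation 1, out = floor((in + 2 × padding - kernel_size) / stride) + 1. A plain-Python sketch; conv_out is a hypothetical helper name:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of nn.Conv2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28  # MNIST images are 28x28
for _ in range(2):  # two Conv2d layers, each with kernel=5, stride=2, padding=2
    size = conv_out(size, kernel=5, stride=2, padding=2)

flat_features = size * size * 128  # 128 channels after the second convolution
print(size, flat_features)  # 7 6272
```

If the input resolution or the convolution settings change, the in_features of the last linear layer must be recomputed the same way.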

Step 4: Instantiate the Generator and Discriminator

Here, we create a generator instance with a noise dimension of NOISE_DIM = 100; the generator is responsible for generating fake images from random noise. Next, we create a discriminator instance, which learns to distinguish between real and fake images.

# Noise dimension
NOISE_DIM = 100

# Generator and discriminator
generator = Generator(NOISE_DIM)
discriminator = Discriminator()
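For a rough sense of model size, the learnable parameters can be counted by hand: a linear layer has in × out weights plus out biases, a convolution or transposed convolution has in_channels × out_channels × kernel² weights plus out_channels biases, and BatchNorm2d learns one scale and one shift per channel (its running statistics are buffers, not parameters). The plain-Python sketch below tallies both models; the helper names are hypothetical, and the totals should agree with sum(p.numel() for p in model.parameters()) in PyTorch:

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases

def conv_params(c_in, c_out, k):
    return c_in * c_out * k * k + c_out  # same count for Conv2d and ConvTranspose2d

def bn_params(c):
    return 2 * c  # learnable scale (gamma) and shift (beta) per channel

generator_total = (
    linear_params(100, 7 * 7 * 256)  # NOISE_DIM = 100
    + conv_params(256, 128, 5) + bn_params(128)
    + conv_params(128, 64, 5) + bn_params(64)
    + conv_params(64, 1, 5)
)

discriminator_total = (
    conv_params(1, 64, 5) + bn_params(64)
    + conv_params(64, 128, 5) + bn_params(128)
    + linear_params(7 * 7 * 128, 1)
)

print(f"Generator: {generator_total:,} parameters")          # Generator: 2,293,121 parameters
print(f"Discriminator: {discriminator_total:,} parameters")  # Discriminator: 213,249 parameters
```

The generator is roughly ten times larger than the discriminator, and more than half of its parameters sit in the initial linear projection.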

Step 5: Device Configuration

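The models are moved to a GPU when CUDA is available and stay on the CPU otherwise. Every tensor passed to them must live on the same device, which is why the training loop moves each batch of real images with .to(device) and creates noise with device=device.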

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)
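A minimal sketch of the same pattern on a tiny stand-in model (nn.Linear(4, 1) is purely illustrative, not part of the GAN): parameters are moved once with .to(device), and every input tensor is created on, or moved to, that same device before the forward pass. Mixing a CUDA model with CPU tensors raises a runtime error.

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical stand-in model, moved to the selected device once
model = nn.Linear(4, 1).to(device)

# Create the input directly on the same device as the model's parameters
x = torch.randn(8, 4, device=device)
y = model(x)

print(y.device.type, tuple(y.shape))  # "cpu (8, 1)" on a machine without CUDA
```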

Step 6: Set Loss Function, Optimizer and Hyperparameters

In this step, we use Binary Cross Entropy with Logits (nn.BCEWithLogitsLoss) as the loss function. It combines a sigmoid and binary cross-entropy in a single, numerically stable operation, which suits the binary task of classifying images as real or fake and lets the discriminator output raw logits. We initialize two Adam optimizers, one for the generator (generator_optimizer) and one for the discriminator (discriminator_optimizer), each with a learning rate of 0.0002 and betas=(0.5, 0.999).

We set the number of epochs (NUM_EPOCHS) to 5 and the batch size (BATCH_SIZE) to 256. NUM_EPOCHS is the number of full passes over the training set, and BATCH_SIZE is the number of images in each training batch.

# Loss function
criterion = nn.BCEWithLogitsLoss()

# Optimizers
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 256
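The per-sample computation inside nn.BCEWithLogitsLoss can be sketched in plain Python as follows. bce_with_logits and naive_bce are hypothetical helpers; the PyTorch loss additionally averages over the batch by default (reduction='mean'). The stable form gives the same value as applying a sigmoid and then binary cross-entropy, but it avoids taking log(0) when the logit is large:

```python
import math

def bce_with_logits(logit, target):
    """Binary cross-entropy on a raw logit z with label y, in the stable form
    max(z, 0) - z * y + log(1 + exp(-|z|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def naive_bce(logit, target):
    """Same loss computed as sigmoid followed by -[y log p + (1 - y) log(1 - p)]."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

print(round(bce_with_logits(0.0, 1.0), 4))  # 0.6931: undecided discriminator (p = 0.5)
print(round(bce_with_logits(5.0, 1.0), 4))  # 0.0067: confident and correct
print(round(bce_with_logits(5.0, 0.0), 4))  # 5.0067: confident and wrong
```

For a logit such as 40.0 with label 0, naive_bce fails because the sigmoid rounds to exactly 1.0 in floating point and math.log(0) raises an error, while bce_with_logits simply returns about 40.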

Step 7: DataLoader

This section of the code prepares the MNIST dataset for training the GAN:

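  • Transformations: Images are converted to tensors with values in [0, 1] and then normalized to the range [-1, 1], matching the range of the generator's Tanh output.
  • Dataset: The MNIST training set (60,000 images) is loaded with these transformations and downloaded to ./data if it is not already present.
  • DataLoader: Splits the data into batches of BATCH_SIZE images, reshuffles it every epoch, and loads the batches during training.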
# DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
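Two quick checks on this step, in plain Python: Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5 per pixel, mapping the [0, 1] output of ToTensor onto [-1, 1]; and with the default drop_last=False, the DataLoader yields ceil(60000 / 256) batches per epoch, which matches the "Step [.../235]" counter in the training log. The normalize helper is hypothetical:

```python
import math

def normalize(pixel, mean=0.5, std=0.5):
    """What transforms.Normalize((0.5,), (0.5,)) does to a single channel value."""
    return (pixel - mean) / std

print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0

NUM_TRAIN_IMAGES = 60_000  # size of the MNIST training split
BATCH_SIZE = 256
batches_per_epoch = math.ceil(NUM_TRAIN_IMAGES / BATCH_SIZE)
last_batch = NUM_TRAIN_IMAGES - (batches_per_epoch - 1) * BATCH_SIZE
print(batches_per_epoch, last_batch)  # 235 96
```

The last batch of each epoch holds only 96 images, which is why the training loop sizes its labels and noise with real_images.size(0) instead of BATCH_SIZE.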

Step 8: Training Process

This training loop iterates over the specified number of epochs, training the GAN by alternating between updating the discriminator and the generator:

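  • For each epoch, it iterates over batches of real images from the DataLoader; the digit labels are discarded because the GAN does not use them.
  • It passes the real images through the discriminator, computes the loss against real labels (ones), and backpropagates this loss.
  • Next, it generates fake images from random noise and computes the discriminator's loss on them against fake labels (zeros). The fake images are detached, so this loss does not backpropagate into the generator. The gradients of the real and fake losses accumulate, and a single discriminator_optimizer.step() updates the discriminator.
  • Finally, it trains the generator: the same fake images are passed through the updated discriminator and the loss is computed against real labels (ones), so the generator is rewarded when the discriminator mistakes its images for real ones. generator_optimizer.step() then updates the generator's parameters.
  • Every 100 steps, the combined discriminator loss (real + fake) and the generator loss are printed to monitor training progress.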
# Training loop
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_loader):
        real_images, _ = data
        real_images = real_images.to(device)

        # Train discriminator with real images
        discriminator_optimizer.zero_grad()
        real_labels = torch.ones(real_images.size(0), 1, device=device)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)
        real_loss.backward()

        # Train discriminator with fake images
        noise = torch.randn(real_images.size(0), NOISE_DIM, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(real_images.size(0), 1, device=device)
        # detach() stops this loss from backpropagating into the generator
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)
        fake_loss.backward()
        # Gradients from real_loss and fake_loss have accumulated; update D once
        discriminator_optimizer.step()

        # Train generator
        generator_optimizer.zero_grad()
        # Fakes are labeled as real: the generator improves when D is fooled
        fake_labels = torch.ones(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images)
        gen_loss = criterion(fake_outputs, fake_labels)
        gen_loss.backward()
        generator_optimizer.step()

        # Print losses
        if i % 100 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], '
                  f'Discriminator Loss: {real_loss.item() + fake_loss.item():.4f}, '
                  f'Generator Loss: {gen_loss.item():.4f}')
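The generator step labels the fake images as real (ones) instead of minimizing log(1 - D(G(z))) directly. A plain-Python sketch of why this helps: if s = sigmoid(z) is the discriminator's probability that a fake is real, the gradient of log(1 - s) with respect to the logit z is -s, while the gradient of -log(s), the BCE loss against a label of 1, is -(1 - s). When the discriminator confidently rejects fakes (s close to 0), the first gradient nearly vanishes while the second stays close to -1. The helper names are hypothetical:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def grad_saturating(z):
    """d/dz of log(1 - sigmoid(z)), the generator term in the minimax objective."""
    return -sigmoid(z)

def grad_non_saturating(z):
    """d/dz of -log(sigmoid(z)), i.e. BCE against a label of 1 as in the generator step."""
    return -(1.0 - sigmoid(z))

# Logit the discriminator might assign to a fake image it confidently rejects
z = -6.0
print(f"saturating:     {grad_saturating(z):.4f}")      # saturating:     -0.0025
print(f"non-saturating: {grad_non_saturating(z):.4f}")  # non-saturating: -0.9975
```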

Step 9: Visualization

Next, we define generate_and_save_images, which uses the trained generator to produce fake images and save them to a file:

  • It sets the generator to evaluation mode, so batch normalization uses its running statistics, and generates fake images from the given noise vectors with gradient tracking disabled.
  • The generated images are reshaped to 28×28 and plotted in a 4×4 grid using Matplotlib.
  • The grid is saved to a PNG file named after the epoch number and then displayed.
  • Finally, the script creates 16 random noise vectors (test_noise) and calls the function to generate and save images with the trained generator.
# Generate and save images
def generate_and_save_images(model, epoch, noise):
    model.eval()
    with torch.no_grad():
        fake_images = model(noise).cpu()
        fake_images = fake_images.view(fake_images.size(0), 28, 28)

        fig = plt.figure(figsize=(4, 4))
        for i in range(fake_images.size(0)):
            plt.subplot(4, 4, i+1)
            plt.imshow(fake_images[i], cmap='gray')
            plt.axis('off')

        # epoch is the 1-based number of completed epochs (e.g. NUM_EPOCHS)
        plt.savefig(f'image_at_epoch_{epoch:04d}.png')
        plt.show()

# Generate test noise
test_noise = torch.randn(16, NOISE_DIM, device=device)
generate_and_save_images(generator, NUM_EPOCHS, test_noise)

Complete Code and Output:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Generator
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.main = nn.Sequential(
            nn.Linear(noise_dim, 7 * 7 * 256),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, 5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)


# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.Flatten(),
            nn.Linear(7 * 7 * 128, 1)
        )

    def forward(self, x):
        return self.main(x)


# Noise dimension
NOISE_DIM = 100

# Generator and discriminator
generator = Generator(NOISE_DIM)
discriminator = Discriminator()

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)

# Loss function
criterion = nn.BCEWithLogitsLoss()

# Optimizers
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 256

# DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Training loop
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_loader):
        real_images, _ = data
        real_images = real_images.to(device)

        # Train discriminator with real images
        discriminator_optimizer.zero_grad()
        real_labels = torch.ones(real_images.size(0), 1, device=device)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)
        real_loss.backward()

        # Train discriminator with fake images
        noise = torch.randn(real_images.size(0), NOISE_DIM, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)
        fake_loss.backward()
        discriminator_optimizer.step()

        # Train generator
        generator_optimizer.zero_grad()
        fake_labels = torch.ones(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images)
        gen_loss = criterion(fake_outputs, fake_labels)
        gen_loss.backward()
        generator_optimizer.step()

        # Print losses
        if i % 100 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], '
                  f'Discriminator Loss: {real_loss.item() + fake_loss.item():.4f}, '
                  f'Generator Loss: {gen_loss.item():.4f}')

# Generate and save images
def generate_and_save_images(model, epoch, noise):
    model.eval()
    with torch.no_grad():
        fake_images = model(noise).cpu()
        fake_images = fake_images.view(fake_images.size(0), 28, 28)

        fig = plt.figure(figsize=(4, 4))
        for i in range(fake_images.size(0)):
            plt.subplot(4, 4, i+1)
            plt.imshow(fake_images[i], cmap='gray')
            plt.axis('off')

        # epoch is the 1-based number of completed epochs (e.g. NUM_EPOCHS)
        plt.savefig(f'image_at_epoch_{epoch:04d}.png')
        plt.show()

# Generate test noise
test_noise = torch.randn(16, NOISE_DIM, device=device)
generate_and_save_images(generator, NUM_EPOCHS, test_noise)

Output:

Epoch [1/5], Step [1/235], Discriminator Loss: 1.6305, Generator Loss: 1.0509
Epoch [1/5], Step [101/235], Discriminator Loss: 0.2560, Generator Loss: 4.2435
Epoch [1/5], Step [201/235], Discriminator Loss: 0.2019, Generator Loss: 5.7860
Epoch [2/5], Step [1/235], Discriminator Loss: 0.0429, Generator Loss: 4.2411
Epoch [2/5], Step [101/235], Discriminator Loss: 0.0505, Generator Loss: 4.4958
Epoch [2/5], Step [201/235], Discriminator Loss: 0.0449, Generator Loss: 4.6327
Epoch [3/5], Step [1/235], Discriminator Loss: 0.0257, Generator Loss: 5.1921
Epoch [3/5], Step [101/235], Discriminator Loss: 0.0354, Generator Loss: 5.5234
Epoch [3/5], Step [201/235], Discriminator Loss: 0.0290, Generator Loss: 5.2325
Epoch [4/5], Step [1/235], Discriminator Loss: 0.0104, Generator Loss: 5.6811
Epoch [4/5], Step [101/235], Discriminator Loss: 0.0097, Generator Loss: 5.6416
Epoch [4/5], Step [201/235], Discriminator Loss: 0.0030, Generator Loss: 6.3280
Epoch [5/5], Step [1/235], Discriminator Loss: 0.0079, Generator Loss: 5.6755
Epoch [5/5], Step [101/235], Discriminator Loss: 0.0097, Generator Loss: 5.9742
Epoch [5/5], Step [201/235], Discriminator Loss: 0.0055, Generator Loss: 6.0514
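One way to read these numbers: each printed loss is a batch mean of -log(p) terms, where p is the probability the discriminator assigns to the target label, so exp(-loss) is the geometric mean of those probabilities. For the discriminator's combined loss (real + fake), exp(-loss) is a lower bound on each of the two geometric means. A plain-Python sketch applied to the last logged line:

```python
import math

final_generator_loss = 6.0514      # last logged "Generator Loss"
final_discriminator_loss = 0.0055  # last logged "Discriminator Loss" (real + fake)

# Geometric-mean probability that D labels a fake image as real
p_fake_judged_real = math.exp(-final_generator_loss)

# Lower bound on D's geometric-mean probability of the correct label,
# for real and fake batches alike
p_d_correct_lower_bound = math.exp(-final_discriminator_loss)

print(f"{p_fake_judged_real:.4f}")       # 0.0024
print(f"{p_d_correct_lower_bound:.4f}")  # 0.9945
```

By this reading, the discriminator still rejects almost all generated images after 5 epochs, which suggests it dominates the generator at this stage of training.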

The generated digits are still blurry because the model was trained for only 5 epochs; training for more epochs typically produces sharper results.
