
Generative Adversarial Networks (GANs) in PyTorch

The aim of this article is to implement a GAN architecture using the PyTorch framework. It provides a comprehensive understanding of GANs in PyTorch along with an in-depth explanation of the code.

Generative Adversarial Networks (GANs) are a class of artificial intelligence algorithms used in unsupervised machine learning. They consist of two neural networks, the generator and the discriminator, which are trained simultaneously through a competitive process. The generator creates new data instances, while the discriminator evaluates whether they are real (from the true data distribution) or fake (produced by the generator). This adversarial training process improves both networks over time.
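For reference, this competition is commonly written as the minimax objective from the original GAN paper, where G is the generator, D is the discriminator, x is a real sample and z is a noise vector:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

The discriminator tries to maximize this value (classify real and fake correctly), while the generator tries to minimize it (fool the discriminator). The implementation below optimizes both networks with binary cross-entropy losses that correspond to this objective.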

Implementing GANs using PyTorch Framework

In this section, we demonstrate the implementation of a Generative Adversarial Network (GAN) architecture for generating realistic handwritten digits from the MNIST dataset, using the following steps:

Step 1: Importing Necessary Libraries

We will import the fundamental PyTorch libraries: torch, torch.nn for building the networks and torch.optim for updating their parameters. torchvision is used for loading and preprocessing the MNIST dataset, making it easier to work with image data in PyTorch, and torchvision.transforms defines the transformations applied to the MNIST images before they are fed into the GAN. matplotlib and numpy are used for visualizing the generated images.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
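Optionally, you can fix the random seeds so that runs are easier to reproduce. This is not part of the original tutorial code, just a common convenience, and it does not guarantee bit-exact results on GPU:

# Optional: fix random seeds for more reproducible runs (illustrative)
torch.manual_seed(42)
np.random.seed(42)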

Step 2: Define Generator Function

The Generator class maps a random noise vector to a 28x28 image. A linear layer projects the noise to a 7x7x256 feature map, which transposed convolutions (with batch normalization and ReLU activations) upsample to 14x14 and finally 28x28. A Tanh activation scales the output pixels to the range [-1, 1], matching the normalized MNIST images.

# Generator
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.main = nn.Sequential(
            nn.Linear(noise_dim, 7 * 7 * 256),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, 5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)
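As a quick sanity check (not part of the original code), you can pass a batch of random noise through the generator and confirm that it produces 1x28x28 images:

# Illustrative shape check for the generator
g = Generator(noise_dim=100)
z = torch.randn(4, 100)          # batch of 4 noise vectors
print(g(z).shape)                # expected: torch.Size([4, 1, 28, 28])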

Step 3: Define Discriminator Function

The Discriminator class is a binary classifier that takes a 28x28 image and outputs a single logit indicating whether the image is real or fake. Two strided convolutions (with LeakyReLU activations and batch normalization) downsample the image to a 7x7x128 feature map, which is flattened and passed through a linear layer.

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.Flatten(),
            nn.Linear(7 * 7 * 128, 1)
        )

    def forward(self, x):
        return self.main(x)
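Similarly, a quick check (illustrative, not in the original code) confirms that the discriminator maps a batch of 28x28 images to one logit per image:

# Illustrative shape check for the discriminator
d = Discriminator()
x = torch.randn(4, 1, 28, 28)    # batch of 4 image-shaped tensors
print(d(x).shape)                # expected: torch.Size([4, 1])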

Step 4: Combine the Generator and Discriminator Function

Here, we create a generator instance with the specified noise dimension; it will be responsible for generating fake images from random noise. We also create a discriminator instance to distinguish between real and fake images.

# Noise dimension
NOISE_DIM = 100

# Generator and discriminator
generator = Generator(NOISE_DIM)
discriminator = Discriminator()

Step 5: Device Configuration

Device configuration moves both models to a GPU if one is available and falls back to the CPU otherwise, so that training runs on the best available hardware.

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)

Step 6: Set Loss Function, Optimizer and Hyperparameters

In this section of the code, we use Binary Cross Entropy with Logits Loss (nn.BCEWithLogitsLoss) as the loss function. It is designed for binary classification and suits the task of distinguishing between real and fake images. We initialize two Adam optimizers, one for the generator (generator_optimizer) and one for the discriminator (discriminator_optimizer), each with a learning rate of 0.0002 and betas of (0.5, 0.999).

We set the number of epochs (NUM_EPOCHS) to 5 and the batch size (BATCH_SIZE) to 256. These hyperparameters determine the number of iterations and the size of the data batches used for training the GAN.

# Loss function
criterion = nn.BCEWithLogitsLoss()

# Optimizers
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 256
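As a small illustration (not part of the tutorial code), BCEWithLogitsLoss applies a sigmoid internally, which is why the discriminator returns raw logits rather than probabilities:

# Illustrative only: BCEWithLogitsLoss combines a sigmoid with binary cross-entropy
logit = torch.tensor([[2.0]])        # a confident "real" prediction
label = torch.ones(1, 1)             # target: real
print(criterion(logit, label))       # approx 0.1269, i.e. -log(sigmoid(2.0))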

Step 7: DataLoader

This section of the code prepares the MNIST dataset for training the GAN. The images are converted to tensors and normalized with mean 0.5 and standard deviation 0.5, which maps pixel values to the range [-1, 1] and matches the Tanh output of the generator:

# DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
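If you want to verify the preprocessing (an optional check, not in the original code), you can pull one batch from the loader and confirm its shape and pixel range:

# Optional check: one batch of preprocessed MNIST images
images, labels = next(iter(train_loader))
print(images.shape)                              # torch.Size([256, 1, 28, 28])
print(images.min().item(), images.max().item())  # roughly -1.0 and 1.0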

Step 8: Training Process

This training loop iterates over the specified number of epochs, alternating between updating the discriminator and the generator. The discriminator is trained on real images labeled 1 and on generated images labeled 0; fake_images.detach() is used so that the discriminator's loss does not backpropagate into the generator. The generator is then updated to make the discriminator classify its outputs as real (label 1):

# Training loop
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_loader):
        real_images, _ = data
        real_images = real_images.to(device)

        # Train discriminator with real images
        discriminator_optimizer.zero_grad()
        real_labels = torch.ones(real_images.size(0), 1, device=device)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)
        real_loss.backward()

        # Train discriminator with fake images
        noise = torch.randn(real_images.size(0), NOISE_DIM, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)
        fake_loss.backward()
        discriminator_optimizer.step()

        # Train generator
        generator_optimizer.zero_grad()
        fake_labels = torch.ones(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images)
        gen_loss = criterion(fake_outputs, fake_labels)
        gen_loss.backward()
        generator_optimizer.step()

        # Print losses
        if i % 100 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], '
                  f'Discriminator Loss: {real_loss.item() + fake_loss.item():.4f}, '
                  f'Generator Loss: {gen_loss.item():.4f}')
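If you want to keep the trained weights, a minimal sketch (the filenames are arbitrary and not part of the original code) is to save the state dictionaries after training:

# Sketch: persist the trained weights (hypothetical filenames)
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')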

Step 9: Visualization

Now, we define generate_and_save_images to generate fake images with the trained generator, arrange them in a 4x4 grid and save the figure to a file:

# Generate and save images
def generate_and_save_images(model, epoch, noise):
    model.eval()
    with torch.no_grad():
        fake_images = model(noise).cpu()
        fake_images = fake_images.view(fake_images.size(0), 28, 28)

        fig = plt.figure(figsize=(4, 4))
        for i in range(fake_images.size(0)):
            plt.subplot(4, 4, i+1)
            plt.imshow(fake_images[i], cmap='gray')
            plt.axis('off')

        plt.savefig(f'image_at_epoch_{epoch+1:04d}.png')
        plt.show()

# Generate test noise
test_noise = torch.randn(16, NOISE_DIM, device=device)
generate_and_save_images(generator, NUM_EPOCHS, test_noise)
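To reuse the generator later without retraining, a minimal sketch (assuming the generator.pth file saved in the earlier sketch) is:

# Sketch: reload the generator and sample new digits (assumes 'generator.pth' exists)
generator = Generator(NOISE_DIM).to(device)
generator.load_state_dict(torch.load('generator.pth', map_location=device))
generator.eval()
new_noise = torch.randn(16, NOISE_DIM, device=device)
generate_and_save_images(generator, NUM_EPOCHS, new_noise)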

Complete Code and Output:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Generator
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.main = nn.Sequential(
            nn.Linear(noise_dim, 7 * 7 * 256),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, 5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)


# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.Flatten(),
            nn.Linear(7 * 7 * 128, 1)
        )

    def forward(self, x):
        return self.main(x)


# Noise dimension
NOISE_DIM = 100

# Generator and discriminator
generator = Generator(NOISE_DIM)
discriminator = Discriminator()

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)

# Loss function
criterion = nn.BCEWithLogitsLoss()

# Optimizers
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 256

# DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Training loop
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_loader):
        real_images, _ = data
        real_images = real_images.to(device)

        # Train discriminator with real images
        discriminator_optimizer.zero_grad()
        real_labels = torch.ones(real_images.size(0), 1, device=device)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)
        real_loss.backward()

        # Train discriminator with fake images
        noise = torch.randn(real_images.size(0), NOISE_DIM, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)
        fake_loss.backward()
        discriminator_optimizer.step()

        # Train generator
        generator_optimizer.zero_grad()
        fake_labels = torch.ones(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images)
        gen_loss = criterion(fake_outputs, fake_labels)
        gen_loss.backward()
        generator_optimizer.step()

        # Print losses
        if i % 100 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], '
                  f'Discriminator Loss: {real_loss.item() + fake_loss.item():.4f}, '
                  f'Generator Loss: {gen_loss.item():.4f}')

# Generate and save images
def generate_and_save_images(model, epoch, noise):
    model.eval()
    with torch.no_grad():
        fake_images = model(noise).cpu()
        fake_images = fake_images.view(fake_images.size(0), 28, 28)

        fig = plt.figure(figsize=(4, 4))
        for i in range(fake_images.size(0)):
            plt.subplot(4, 4, i+1)
            plt.imshow(fake_images[i], cmap='gray')
            plt.axis('off')

        plt.savefig(f'image_at_epoch_{epoch+1:04d}.png')
        plt.show()

# Generate test noise
test_noise = torch.randn(16, NOISE_DIM, device=device)
generate_and_save_images(generator, NUM_EPOCHS, test_noise)

Output:

Epoch [1/5], Step [1/235], Discriminator Loss: 1.6305, Generator Loss: 1.0509
Epoch [1/5], Step [101/235], Discriminator Loss: 0.2560, Generator Loss: 4.2435
Epoch [1/5], Step [201/235], Discriminator Loss: 0.2019, Generator Loss: 5.7860
Epoch [2/5], Step [1/235], Discriminator Loss: 0.0429, Generator Loss: 4.2411
Epoch [2/5], Step [101/235], Discriminator Loss: 0.0505, Generator Loss: 4.4958
Epoch [2/5], Step [201/235], Discriminator Loss: 0.0449, Generator Loss: 4.6327
Epoch [3/5], Step [1/235], Discriminator Loss: 0.0257, Generator Loss: 5.1921
Epoch [3/5], Step [101/235], Discriminator Loss: 0.0354, Generator Loss: 5.5234
Epoch [3/5], Step [201/235], Discriminator Loss: 0.0290, Generator Loss: 5.2325
Epoch [4/5], Step [1/235], Discriminator Loss: 0.0104, Generator Loss: 5.6811
Epoch [4/5], Step [101/235], Discriminator Loss: 0.0097, Generator Loss: 5.6416
Epoch [4/5], Step [201/235], Discriminator Loss: 0.0030, Generator Loss: 6.3280
Epoch [5/5], Step [1/235], Discriminator Loss: 0.0079, Generator Loss: 5.6755
Epoch [5/5], Step [101/235], Discriminator Loss: 0.0097, Generator Loss: 5.9742
Epoch [5/5], Step [201/235], Discriminator Loss: 0.0055, Generator Loss: 6.0514

The generated images are not very sharp because the model was trained for only 5 epochs; training for more epochs will produce better results.
