Generative Adversarial Network (GAN)

Generative Adversarial Networks, or GANs, represent a cutting-edge approach to generative modeling within deep learning, often leveraging architectures like convolutional neural networks. The goal of generative modeling is to automatically discover the patterns in input data so that the model can produce new examples that could plausibly have been drawn from the original dataset.

GANs tackle this challenge by framing it as a supervised learning problem with two sub-models: the generator, which learns to produce new examples, and the discriminator, which learns to classify examples as real (from the dataset) or fake (generated). The two models are trained together adversarially until the generator produces samples realistic enough that the discriminator is fooled about half the time, meaning it can do no better than chance.

This dynamic field of GANs has rapidly evolved, showcasing remarkable capabilities in generating lifelike content across various domains. Notable applications include image-to-image translation tasks and the creation of photorealistic images indistinguishable from real photos, demonstrating the transformative potential of GANs in the realm of generative modeling.

What is a Generative Adversarial Network?

Generative Adversarial Networks (GANs) are a powerful class of neural networks used for unsupervised learning. A GAN is made up of two neural networks, a generator and a discriminator, which are trained adversarially to produce artificial data that closely resembles real data. The generator transforms random noise into synthetic samples and tries to fool the discriminator, whose job is to distinguish generated data from genuine data. This competition pushes both networks to improve and ultimately yields realistic, high-quality samples. GANs have proven to be highly versatile AI tools, with extensive use in image synthesis, style transfer, and text-to-image synthesis, and they have revolutionized generative modeling.

Generative Adversarial Networks (GANs) can be broken down into three parts:

  • Generative: The goal is to learn a generative model, which describes how data is generated in terms of a probabilistic model.
  • Adversarial: The word adversarial refers to setting one thing against another. In a GAN, the generator's output is pitted against the real images in the dataset: a second model, the discriminator, tries to tell real images from fake ones.
  • Networks: Both the generator and the discriminator are deep neural networks, trained together.

Architecture of GAN

A Generative Adversarial Network (GAN) is composed of two primary parts, which are the Generator and the Discriminator.

Generator Model

The generator model is the component of a GAN responsible for creating new, realistic data. It takes random noise as input and converts it into complex data samples, such as text or images, and is usually implemented as a deep neural network. Through training, its layers of learnable parameters come to capture the underlying distribution of the training data: the generator uses backpropagation to fine-tune those parameters, adjusting its output so that its samples increasingly resemble real data. The generator is successful when it can produce high-quality, varied samples that fool the discriminator.

Generator Loss (J_G)

For generated samples, the generator minimizes the negative log-probability that the discriminator assigns to them being real. This loss incentivizes the generator to produce samples that the discriminator is likely to classify as real, i.e. D(G(z_i)) close to 1, so that log D(G(z_i)) is close to 0.
J_{G} = -\frac{1}{m} \sum_{i=1}^{m} \log D(G(z_{i}))
Where,

  • J_G measures how well the generator is fooling the discriminator: the lower it is, the more convincing the generated samples are to the current discriminator.
  • log D(G(z_i)) is the log-probability that the discriminator assigns to the generated sample G(z_i) being real, where z_i is the i-th noise vector and m is the batch size.
  • The generator aims to minimize this loss, which pushes D(G(z_i)) toward 1.
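To make the formula concrete, here is a minimal pure-Python sketch that evaluates J_G for a batch of discriminator outputs D(G(z_i)). The function name generator_loss and the sample probabilities are illustrative assumptions; the PyTorch implementation later in this article computes the same quantity through nn.BCELoss with a target of 1.

```python
import math

def generator_loss(d_fake):
    """J_G = -(1/m) * sum(log D(G(z_i))) for discriminator outputs on generated samples."""
    m = len(d_fake)
    return -sum(math.log(p) for p in d_fake) / m

# An undecided discriminator (0.5 everywhere) gives J_G = log 2.
print(round(generator_loss([0.5, 0.5, 0.5]), 4))    # 0.6931
# Convincing fakes (D close to 1) drive the loss toward 0.
print(round(generator_loss([0.9, 0.95, 0.99]), 4))  # 0.0556
# Obvious fakes (D close to 0) give a large loss.
print(round(generator_loss([0.1]), 4))              # 2.3026
```

The loss falls as the discriminator is fooled more often, which is exactly the pressure the generator is trained under.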

Discriminator Model

The discriminator is a neural network that distinguishes generated inputs from real ones. It acts as a binary classifier: for each input sample, it outputs the probability that the sample is real. Over the course of training, it learns to tell genuine data from the dataset apart from artificial samples created by the generator, progressively refining its parameters. When the data are images, its architecture usually relies on convolutional layers; other modalities call for structures suited to that kind of data. The discriminator's objective is to label generated samples as fake and real samples as real as accurately as possible. The interplay between the two networks makes the discriminator increasingly discerning, which in turn pushes the generator, and the GAN as a whole, toward producing highly realistic synthetic data.

Discriminator Loss (J_D)

The discriminator minimizes the negative log-likelihood of classifying both real and generated samples correctly. This loss incentivizes it to label real samples as real (D(x_i) close to 1, so log D(x_i) close to 0) and generated samples as fake (D(G(z_i)) close to 0, so log(1 - D(G(z_i))) close to 0).
J_{D} = -\frac{1}{m} \sum_{i=1}^{m} \log D(x_{i}) - \frac{1}{m} \sum_{i=1}^{m} \log\left(1 - D(G(z_{i}))\right)

  • J_D assesses the discriminator's ability to discern between generated and real samples.
  • log D(x_i) is the log-probability that the discriminator correctly classifies the real sample x_i as real.
  • log(1 - D(G(z_i))) is the log-probability that the discriminator correctly classifies the generated sample G(z_i) as fake.
  • The discriminator aims to minimize this loss by accurately identifying real and artificial samples.
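A matching pure-Python sketch of J_D is below; discriminator_loss and the sample probabilities are illustrative assumptions. As in the formula, real and generated batches are assumed to have the same size m.

```python
import math

def discriminator_loss(d_real, d_fake):
    """J_D = -(1/m) sum(log D(x_i)) - (1/m) sum(log(1 - D(G(z_i))))."""
    assert len(d_real) == len(d_fake), "the formula uses the same batch size m for both terms"
    m = len(d_real)
    real_term = -sum(math.log(p) for p in d_real) / m
    fake_term = -sum(math.log(1.0 - p) for p in d_fake) / m
    return real_term + fake_term

# A fully confused discriminator (0.5 everywhere) scores 2 * log 2.
print(round(discriminator_loss([0.5, 0.5], [0.5, 0.5]), 4))  # 1.3863
# A confident, correct discriminator (real -> 0.9, fake -> 0.1) scores much lower.
print(round(discriminator_loss([0.9, 0.9], [0.1, 0.1]), 4))  # 0.2107
```

In the training loop later in this article, d_loss = (real_loss + fake_loss) / 2 is exactly half of J_D, so it has the same minimizer.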

Minimax Loss

In a Generative Adversarial Network (GAN), the minimax objective is given by:

\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log(1 - D(G(z)))]
Where,

  • G is the generator network and D is the discriminator network.
  • x represents real data samples drawn from the true data distribution p_{data}(x).
  • z represents random noise sampled from a prior distribution p_z(z) (usually a normal or uniform distribution).
  • D(x) is the discriminator's estimated probability that the real sample x is real.
  • D(G(z)) is the discriminator's estimated probability that the generated sample G(z) is real.
  • The discriminator tries to maximize this objective, while the generator tries to minimize it.
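A standard result from the theoretical analysis of this objective makes it more tangible: for a fixed generator whose samples G(z) follow a distribution p_g, the best discriminator is D*(x) = p_data(x) / (p_data(x) + p_g(x)), and when p_g = p_data the objective at D* equals -log 4. The sketch below checks this numerically on small discrete distributions chosen purely for illustration. It rewrites the expectation over z as an expectation over x ~ p_g, which is equivalent because G(z) is distributed as p_g; the function names are hypothetical helpers.

```python
import math

def optimal_discriminator(p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)); set to 0.5 where both masses are zero."""
    return [pd / (pd + pg) if pd + pg > 0 else 0.5 for pd, pg in zip(p_data, p_g)]

def value(p_data, p_g, d):
    """V(D, G) = E_{x~p_data}[log D(x)] + E_{x~p_g}[log(1 - D(x))] on a discrete support."""
    real_term = sum(pd * math.log(dx) for pd, dx in zip(p_data, d) if pd > 0)
    fake_term = sum(pg * math.log(1.0 - dx) for pg, dx in zip(p_g, d) if pg > 0)
    return real_term + fake_term

p_data = [0.1, 0.2, 0.3, 0.4]

# Generator matches the data: D* is 0.5 everywhere and V = -log 4.
d_star = optimal_discriminator(p_data, p_data)
print(d_star)                                   # [0.5, 0.5, 0.5, 0.5]
print(round(value(p_data, p_data, d_star), 4))  # -1.3863

# Generator is off: the best discriminator reaches a value above -log 4.
p_g = [0.4, 0.3, 0.2, 0.1]
print(value(p_data, p_g, optimal_discriminator(p_data, p_g)) > -math.log(4))  # True
```

In other words, once the generator matches the data perfectly, even the optimal discriminator can only output 0.5, which is the "fooled half the time" equilibrium described at the start of this article.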

How does a GAN work?

A Generative Adversarial Network (GAN) is a framework made up of two neural networks, a generator and a discriminator, that are trained simultaneously in an adversarial fashion. The generator produces synthetic data that attempts to imitate real data, while the discriminator tries to separate generated data from real data. As training proceeds, the generator becomes better at producing realistic samples that trick the discriminator, and the discriminator in turn becomes better at spotting fakes. This back-and-forth competition yields increasingly convincing synthetic data that is hard to distinguish from the real thing, which is why GANs are an effective tool for producing realistic, high-quality outputs in fields such as image and text generation.

Different Types of GAN Models

  1. Vanilla GAN: This is the simplest type of GAN. Here, the Generator and the Discriminator are simple multi-layer perceptrons, and the algorithm simply optimizes the minimax objective using stochastic gradient descent.
  2. Conditional GAN (CGAN): CGAN is a deep learning method in which generation is conditioned on extra information. In CGAN, an additional parameter 'y' (for example, a class label) is fed to the Generator so that it generates data corresponding to that condition. The labels are also given to the Discriminator as input, to help it distinguish real data from fake generated data.
  3. Deep Convolutional GAN (DCGAN): DCGAN is one of the most popular and most successful implementations of GAN. It uses ConvNets in place of multi-layer perceptrons. The ConvNets are implemented without max pooling, which is replaced by strided convolutions, and fully connected hidden layers are removed.
  4. Laplacian Pyramid GAN (LAPGAN): The Laplacian pyramid is a linear invertible image representation consisting of a set of band-pass images, spaced an octave apart, plus a low-frequency residual. This approach uses multiple Generator and Discriminator networks, one pair for each level of the Laplacian pyramid, and is used mainly because it produces very high-quality images. The image is first down-sampled at each level of the pyramid and then up-scaled again level by level in a backward pass, where a Conditional GAN at each level adds generated detail, until the image reaches its original size.
  5. Super Resolution GAN (SRGAN): As the name suggests, SRGAN combines a deep neural network with an adversarial network in order to produce higher-resolution images. This type of GAN is particularly useful for up-scaling native low-resolution images while enhancing their details and minimizing errors.
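For the conditional GAN described above, one common way to feed the extra information y to the generator is to one-hot encode the class label and concatenate it with the noise vector z; the discriminator receives the same label alongside its input. The plain-Python sketch below shows only that input-construction step. The helper names, the latent size of 100 and the 10 classes are illustrative assumptions.

```python
import random

def one_hot(label, num_classes):
    """Encode an integer class label as a one-hot list."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

def conditional_generator_input(latent_dim, label, num_classes, rng=random):
    """CGAN generator input: noise z ~ N(0, 1) concatenated with the one-hot label y."""
    z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
    return z + one_hot(label, num_classes)

rng = random.Random(0)
x = conditional_generator_input(latent_dim=100, label=3, num_classes=10, rng=rng)
print(len(x))   # 110: 100 noise values followed by 10 label slots
print(x[100:])  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Because the label is part of the input, the same trained generator can be asked for a specific class simply by changing y while sampling fresh noise.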

Implementation of a GAN (Generative Adversarial Network)

Importing the required libraries

Python3




import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
 
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


The following PyTorch code builds a GAN and trains it on the CIFAR-10 image dataset, alternating between discriminator and generator updates, printing the losses to track progress, and displaying generated images every tenth epoch. This first block imports the required libraries and selects a GPU if one is available.

Defining a Transform

Python3




# Define a basic transform
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


This code uses PyTorch's transforms.Compose to define a simple image transform. It converts each image to a tensor (with pixel values scaled to [0, 1]) and then normalizes each of the three color channels to the range [-1, 1].
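transforms.Normalize computes (x - mean) / std for each channel. With mean = std = 0.5, pixels already scaled to [0, 1] by ToTensor land in [-1, 1], the same range as the generator's final Tanh, so real and generated images are on the same scale. A small pure-Python check of that arithmetic (the function names are illustrative):

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-channel normalization as applied by transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping, e.g. to turn generator outputs back into [0, 1] pixels."""
    return y * std + mean

for pixel in (0.0, 0.5, 1.0):
    # prints 0.0 -> -1.0, then 0.5 -> 0.0, then 1.0 -> 1.0
    print(pixel, "->", normalize(pixel))
```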

Loading the Dataset

Python3




train_dataset = datasets.CIFAR10(root='./data', train=True,
                                 download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(train_dataset,
                                         batch_size=32, shuffle=True)


This code creates the CIFAR-10 training dataset: it specifies a root directory, selects the training split, downloads the data if needed, and applies the transform defined above. It then wraps the dataset in a DataLoader that serves shuffled batches of 32 images.
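The CIFAR-10 training split contains 50,000 images, so with batch_size=32 each epoch has ceil(50000 / 32) = 1563 batches, the last one holding only the 16 leftover images (the DataLoader keeps it because drop_last defaults to False). This is where the "/1563" in the training output below comes from. The arithmetic:

```python
import math

num_images = 50_000  # size of the CIFAR-10 training split
batch_size = 32

num_batches = math.ceil(num_images / batch_size)
last_batch = num_images - (num_batches - 1) * batch_size
print(num_batches, last_batch)  # 1563 16
```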

Defining parameters to be used in later processes

Python3




# Hyperparameters
latent_dim = 100
lr = 0.0002
beta1 = 0.5
beta2 = 0.999
num_epochs = 10


These are the hyperparameters of the GAN. latent_dim is the dimensionality of the latent (noise) space. lr is the optimizers' learning rate. beta1 and beta2 are the coefficients of the Adam optimizer. num_epochs is the total number of training epochs.
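To show what beta1 and beta2 control, here is a single-parameter sketch of the Adam update rule in plain Python: beta1 sets how quickly the running mean of gradients (momentum) decays, and beta2 does the same for the running mean of squared gradients. The helper name and the eps value are assumptions for illustration (1e-8 happens to be torch.optim.Adam's default); PyTorch applies an equivalent update to every parameter tensor.

```python
import math

def adam_step(theta, grad, m, v, t, lr=0.0002, beta1=0.5, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad       # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2  # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)             # bias correction for the zero initialization
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

theta, m, v = 1.0, 0.0, 0.0
for t in range(1, 4):
    theta, m, v = adam_step(theta, grad=2.0, m=m, v=v, t=t)
print(theta)  # about 0.9994: each step moved theta by roughly lr
```

With a constant gradient, bias correction makes each early step move the parameter by roughly lr regardless of the gradient's magnitude. Using beta1 = 0.5 instead of Adam's default 0.9 shortens the momentum memory, a setting commonly used for GANs since DCGAN.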

Defining a Utility Class to Build the Generator

Python3




# Define the generator
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
 
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8)),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128, momentum=0.78),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64, momentum=0.78),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh()
        )
 
    def forward(self, z):
        img = self.model(z)
        return img


This code defines the generator architecture in PyTorch. The Generator class inherits from nn.Module and wraps a sequential model made of linear, reshaping (Unflatten), upsampling, convolutional, batch normalization, ReLU, and Tanh layers. Given a latent vector z, the network synthesizes an image (img), which is the generator's output. Through this series of learned transformations, random noise from the latent space is turned into a 3×32×32 image whose pixel values lie in [-1, 1] because of the final Tanh.
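It can help to trace the tensor shape through the generator layer by layer. The sketch below does this bookkeeping in plain Python for a single sample (batch dimension omitted). It is a hand-written trace, not a call into PyTorch, and the helper names are illustrative. BatchNorm2d and ReLU leave the shape unchanged and are therefore skipped, and every Conv2d here uses kernel 3, stride 1 and padding 1, which preserves height and width.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a Conv2d layer: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def trace_generator(latent_dim=100):
    shapes = [("input z", (latent_dim,))]
    shapes.append(("Linear", (128 * 8 * 8,)))
    c, h, w = 128, 8, 8
    shapes.append(("Unflatten", (c, h, w)))
    h, w = h * 2, w * 2
    shapes.append(("Upsample x2", (c, h, w)))
    c, h, w = 128, conv_out(h), conv_out(w)
    shapes.append(("Conv2d 128->128", (c, h, w)))
    h, w = h * 2, w * 2
    shapes.append(("Upsample x2", (c, h, w)))
    c, h, w = 64, conv_out(h), conv_out(w)
    shapes.append(("Conv2d 128->64", (c, h, w)))
    c, h, w = 3, conv_out(h), conv_out(w)
    shapes.append(("Conv2d 64->3 + Tanh", (c, h, w)))
    return shapes

for name, shape in trace_generator():
    print(f"{name:<22}{shape}")
```

The last line of the trace is (3, 32, 32), the shape of a CIFAR-10 image.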

Defining a Utility Class to Build the Discriminator

Python3




# Define the discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
 
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ZeroPad2d((0, 1, 0, 1)),
            nn.BatchNorm2d(64, momentum=0.82),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128, momentum=0.82),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256, momentum=0.8),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
            nn.Flatten(),
            nn.Linear(256 * 5 * 5, 1),
            nn.Sigmoid()
        )
 
    def forward(self, img):
        validity = self.model(img)
        return validity


This PyTorch code defines the discriminator architecture. The Discriminator class inherits from nn.Module and wraps a sequential model made of convolutional, zero-padding, batch normalization, LeakyReLU, dropout, flatten, and linear layers, followed by a sigmoid. Its input is an image (img), and its output is the image's validity: the probability that the input image is real rather than generated.
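The input size of the final nn.Linear(256 * 5 * 5, 1) follows from the convolution output-size rule, floor((size + 2p - k) / s) + 1, applied to a 32×32 CIFAR-10 image. The plain-Python trace below reproduces that bookkeeping; the conv_out helper is illustrative, and layers that do not change spatial size (BatchNorm2d, LeakyReLU, Dropout) are omitted.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a Conv2d layer: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 32                        # CIFAR-10 images are 3 x 32 x 32
size = conv_out(size, stride=2)  # Conv2d(3, 32, stride=2)    -> 16
size = conv_out(size, stride=2)  # Conv2d(32, 64, stride=2)   -> 8
size = size + 1                  # ZeroPad2d((0, 1, 0, 1)) pads right and bottom -> 9
size = conv_out(size, stride=2)  # Conv2d(64, 128, stride=2)  -> 5
size = conv_out(size, stride=1)  # Conv2d(128, 256, stride=1) -> 5
flat_features = 256 * size * size
print(size, flat_features)       # 5 6400
```

Changing the input resolution or any stride would change this number, and the Linear layer's input size would have to be updated to match.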

Building the Generative Adversarial Network

Python3




# Initialize generator and discriminator
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
 
# Loss function
adversarial_loss = nn.BCELoss()
 
# Optimizers
optimizer_G = optim.Adam(generator.parameters(),
                         lr=lr, betas=(beta1, beta2))
optimizer_D = optim.Adam(discriminator.parameters(),
                         lr=lr, betas=(beta1, beta2))


This snippet creates the generator (Generator) and the discriminator (Discriminator) and moves both models to the selected device (the GPU, if available). Binary Cross-Entropy loss (adversarial_loss), a common choice for GANs, is used as the loss function. Separate Adam optimizers are defined for the generator (optimizer_G) and the discriminator (optimizer_D), both using the learning rate and betas set earlier.
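nn.BCELoss computes, averaged over the batch, -[y log p + (1 - y) log(1 - p)] for a predicted probability p and target y. Plugging in the targets used in the training loop shows how it reproduces the losses defined earlier: target 1 ("valid") on D(G(z)) gives J_G, while target 1 on real images plus target 0 on generated images gives the two terms of J_D. The bce helper below is an illustrative plain-Python stand-in (the PyTorch version additionally clamps the log terms), and the discriminator outputs are made-up numbers.

```python
import math

def bce(probs, targets):
    """Mean binary cross-entropy: -mean(y * log(p) + (1 - y) * log(1 - p))."""
    total = 0.0
    for p, y in zip(probs, targets):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(probs)

d_real = [0.8, 0.7]  # hypothetical discriminator outputs on real images
d_fake = [0.3, 0.1]  # hypothetical discriminator outputs on generated images

# Generator step: fakes labelled "valid" (1) -> -mean(log D(G(z))) = J_G
g_loss = bce(d_fake, [1, 1])
# Discriminator step: real -> 1, fake -> 0, averaged as in the training loop
d_loss = (bce(d_real, [1, 1]) + bce(d_fake, [0, 0])) / 2
print(round(g_loss, 4), round(d_loss, 4))  # 1.7533 0.2605
```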

Training the Generative Adversarial Network

Python3




# Training loop
for epoch in range(num_epochs):
    for i, batch in enumerate(dataloader):
        # Each batch is an [images, labels] pair; only the images are needed
        real_images = batch[0].to(device)
 
        # Adversarial ground truths
        valid = torch.ones(real_images.size(0), 1, device=device)
        fake = torch.zeros(real_images.size(0), 1, device=device)
 
        # ---------------------
        #  Train Discriminator
        # ---------------------
 
        optimizer_D.zero_grad()
 
        # Sample noise as generator input
        z = torch.randn(real_images.size(0), latent_dim, device=device)
 
        # Generate a batch of images
        fake_images = generator(z)
 
        # Measure discriminator's ability
        # to classify real and fake images
        real_loss = adversarial_loss(discriminator(real_images), valid)
        # detach() keeps this loss from backpropagating into the generator
        fake_loss = adversarial_loss(discriminator(fake_images.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
 
        # Backward pass and optimize
        d_loss.backward()
        optimizer_D.step()
 
        # -----------------
        #  Train Generator
        # -----------------
 
        optimizer_G.zero_grad()
 
        # Generate a batch of images
        gen_images = generator(z)
 
        # Adversarial loss
        g_loss = adversarial_loss(discriminator(gen_images), valid)
 
        # Backward pass and optimize
        g_loss.backward()
        optimizer_G.step()
 
        # ---------------------
        #  Progress Monitoring
        # ---------------------
 
        if (i + 1) % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}] "
                f"Batch {i+1}/{len(dataloader)} "
                f"Discriminator Loss: {d_loss.item():.4f} "
                f"Generator Loss: {g_loss.item():.4f}"
            )
 
    # Display a grid of generated images every 10 epochs
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            generated = generator(z).detach().cpu()
            grid = torchvision.utils.make_grid(generated, nrow=4, normalize=True)
            # make_grid returns (C, H, W); matplotlib expects (H, W, C)
            plt.imshow(grid.permute(1, 2, 0).numpy())
            plt.axis("off")
            plt.show()


Output:

Epoch [10/10] Batch 1300/1563 Discriminator Loss: 0.4473 Generator Loss: 0.9555
Epoch [10/10] Batch 1400/1563 Discriminator Loss: 0.6643 Generator Loss: 1.0215
Epoch [10/10] Batch 1500/1563 Discriminator Loss: 0.4720 Generator Loss: 2.5027

GAN Output

This code implements the GAN training loop. In each epoch, it iterates over the batches of training data. The discriminator (updated by optimizer_D) is trained to distinguish real images from fake ones, while the generator (updated by optimizer_G) is trained to produce realistic images that trick the discriminator. For each network, the adversarial loss is computed and backpropagated, and the parameters are updated by that network's Adam optimizer. Progress is tracked by printing the discriminator and generator losses every 100 batches. For a visual check of training, a grid of generated images is displayed every 10 epochs; with num_epochs = 10, this happens once, after the final epoch.

Application Of Generative Adversarial Networks (GANs)

GANs, or Generative Adversarial Networks, have many uses in many different fields. Here are some of the widely recognized uses of GANs:

  1. Image Synthesis and Generation: GANs are often used for image synthesis and generation tasks. By learning the distribution underlying the training data, they can create new, lifelike pictures that resemble it. These generative networks have facilitated the creation of lifelike avatars, high-resolution photographs, and new artwork.
  2. Image-to-Image Translation: GANs can be used for image-to-image translation, where the objective is to convert an input picture from one domain to another while preserving its key features. For instance, GANs can turn daytime pictures into night scenes, transform drawings into realistic images, or change the artistic style of an image.
  3. Text-to-Image Synthesis: GANs have been used to create images from text descriptions. Given a text input such as a phrase or caption, a GAN can produce pictures that match the description. This could change how realistic visual material is produced from text-based instructions.
  4. Data Augmentation: By creating synthetic data samples, GANs can augment existing datasets and increase the robustness and generalizability of machine learning models.
  5. Image Super-Resolution: GANs can enhance the resolution and quality of low-resolution images. By training on pairs of low-resolution and high-resolution images, GANs can generate high-resolution images from low-resolution inputs, improving image quality in applications such as medical imaging, satellite imaging, and video enhancement.

Advantages of Generative Adversarial Networks (GANs)

The advantages of the GANs are as follows:

  1. Synthetic data generation: GANs can generate new, synthetic data that resembles some known data distribution, which can be useful for data augmentation, anomaly detection, or creative applications.
  2. High-quality results: GANs can produce high-quality, photorealistic results in image synthesis, video synthesis, music synthesis, and other tasks.
  3. Unsupervised learning: GANs can be trained without labeled data, making them suitable for unsupervised learning tasks, where labeled data is scarce or difficult to obtain.
  4. Versatility: GANs can be applied to a wide range of tasks, including image synthesis, text-to-image synthesis, image-to-image translation, anomaly detection, data augmentation, and others.

Disadvantages of Generative Adversarial Networks (GANs)

The disadvantages of the GANs are as follows:

  1. Training Instability: GANs can be difficult to train, with the risk of instability, mode collapse, or failure to converge.
  2. Computational Cost: GANs can require a lot of computational resources and can be slow to train, especially for high-resolution images or large datasets.
  3. Overfitting: GANs can overfit the training data, producing synthetic data that is too similar to the training data and lacking diversity.
  4. Bias and Fairness: GANs can reflect the biases and unfairness present in the training data, leading to discriminatory or biased synthetic data.
  5. Interpretability and Accountability: GANs can be opaque and difficult to interpret or explain, making it challenging to ensure accountability, transparency, or fairness in their applications.

Frequently Asked Questions (FAQs)

1. What is a Generative Adversarial Network(GAN)?

A GAN is an artificial intelligence model made up of two neural networks, a generator and a discriminator, that are trained together through adversarial training. The generator produces new data instances, while the discriminator evaluates them for authenticity.

2. What are the main applications of GANs?

Generating images and videos, transferring styles, enhancing data, translating images to other images, producing realistic synthetic data for machine learning model training, and super-resolution are just a few of the many uses for GANs.

3. What challenges do GANs face?

GANs face difficulties such as training instability, mode collapse (when the generator produces only a limited range of samples), and keeping the generator and discriminator in balance. Careful design of the model architecture and tuning of the hyperparameters are often necessary.

4. How are GANs evaluated?

The produced samples’ quality, diversity, and resemblance to real data are the main criteria used to assess GANs. For quantitative assessment, metrics like the Fréchet Inception Distance (FID) and Inception Score are frequently employed.
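FID compares two Gaussians fitted to feature embeddings of real and generated images: FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^{1/2}). In the real metric the features come from an Inception network and the covariances are full matrices. The toy sketch below assumes one-dimensional features instead, where the formula reduces to (mu_r - mu_g)^2 + (sigma_r - sigma_g)^2; lower is better, and identical statistics give 0. The function name and feature values are made up for illustration.

```python
import statistics

def frechet_distance_1d(real_feats, fake_feats):
    """Fréchet distance between 1-D Gaussians fitted to two feature samples."""
    mu_r, mu_g = statistics.fmean(real_feats), statistics.fmean(fake_feats)
    sd_r, sd_g = statistics.pstdev(real_feats), statistics.pstdev(fake_feats)
    return (mu_r - mu_g) ** 2 + (sd_r - sd_g) ** 2

real = [0.0, 2.0, 4.0]
print(frechet_distance_1d(real, real))             # 0.0
print(frechet_distance_1d(real, [1.0, 3.0, 5.0]))  # 1.0 (same spread, mean shifted by 1)
```

The mean term penalizes generated features that are centered in the wrong place, and the spread term penalizes a generator whose samples are too uniform (as in mode collapse) or too scattered.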

5. Can GANs be used for tasks other than image generation?

Yes. GANs have been used to generate text, music, 3D models, and other kinds of data. Conditional GANs broaden their usefulness further by generating specific content under given input conditions.

6. What are some famous architectures of GANs?

A few well-known GAN architectures are Progressive GAN (PGAN), Wasserstein GAN (WGAN), Conditional GAN (cGAN), Deep Convolutional GAN (DCGAN), and Vanilla GAN. Each has special qualities and works best with particular kinds of data and tasks.


Last Updated : 23 Nov, 2023