Building an Auxiliary Classifier GAN (AC-GAN) using Keras and TensorFlow

Prerequisites: Generative Adversarial Network

This article demonstrates how to build an Auxiliary Classifier Generative Adversarial Network (AC-GAN) using the Keras and TensorFlow libraries. In an AC-GAN, the generator is conditioned on a class label, and the discriminator predicts both whether an image is real or fake and which class it belongs to. The dataset used is the MNIST image dataset that comes pre-loaded with Keras.

Step 1: Setting up the environment 

Step 1.1: Open the Anaconda prompt in Administrator mode.

Step 1.2: Create a virtual environment using the command: conda create --name acgan python=3.7

Step 1.3: Activate the environment using the command: conda activate acgan

Step 1.4: Install the following libraries (a quick sanity check of the installation is sketched after this list):
         - TensorFlow --> pip install tensorflow==2.1
         - Keras      --> pip install keras==2.3.1
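
As a quick sanity check of the installation (optional, not part of the original steps), you can verify the versions from inside the activated environment:

# Verify that the expected versions are installed in the acgan environment
import tensorflow as tf
import keras

print(tf.__version__)     # expected: 2.1.x
print(keras.__version__)  # expected: 2.3.1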

Step 2: Importing the required libraries

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.models import Sequential, Model
from keras.layers.advanced_activations import LeakyReLU
  
# Import the optimizer from the same Keras package as the layers above,
# so that standalone Keras and tf.keras objects are not mixed
from keras.optimizers import Adam
  
import matplotlib.pyplot as plt
import numpy as np



Step 3: Defining parameters to be used in later processes



# Defining the input shape and the model hyperparameters
image_shape = (28, 28, 1)   # MNIST images: 28 x 28 pixels, single channel
classes = 10                # ten digit classes (0-9)
latent_dim = 100            # size of the generator's noise vector
  
# Defining the optimizer and the losses
# Adam with learning rate 0.0002 and beta_1 = 0.5, a common choice for GANs
optimizer = Adam(0.0002, 0.5)
losses = ['binary_crossentropy', 'sparse_categorical_crossentropy']
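
To see what each of the two losses measures, here is a small hand-computed illustration (the numbers are made-up examples, not taken from the model): binary cross-entropy scores the real/fake prediction, while sparse categorical cross-entropy scores the 10-way digit-class prediction.

# Hand-computed example values (illustration only)
p_real = 0.9                     # discriminator's "this image is real" score
bce = -np.log(p_real)            # binary cross-entropy when the truth is "real"
  
class_probs = np.full(10, 0.1)   # a uniform guess over the 10 digit classes
scce = -np.log(class_probs[3])   # sparse categorical cross-entropy when the true digit is 3
print(bce, scce)                 # prints roughly 0.105 and 2.303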



Step 4: Defining a utility function to build the Generator

def build_generator():
  
    model = Sequential()
      
    # Building the input layer 
    model.add(Dense(128 * 7 * 7, activation="relu", input_dim=latent_dim))
    model.add(Reshape((7, 7, 128)))
    model.add(BatchNormalization(momentum=0.82))
    model.add(UpSampling2D())
  
    model.add(Conv2D(128, (3,3), padding="same"))
    model.add(BatchNormalization(momentum=0.82))
    model.add(Activation("relu"))
    model.add(UpSampling2D())
  
    model.add(Conv2D(64, (3,3), padding="same"))
    model.add(BatchNormalization(momentum=0.82))
    model.add(Activation("relu"))
      
    model.add(Conv2D(1, (3,3), padding='same'))
    model.add(Activation("tanh"))
  
    # Conditioning on the class label: embed the digit label into a
    # latent_dim-sized vector and multiply it element-wise with the noise
    noise = Input(shape=(latent_dim,))
    label = Input(shape=(1,), dtype='int32')
    z = Flatten()(Embedding(classes, latent_dim)(label))
  
    model_input = multiply([noise, z])
    image = model(model_input)
  
    return Model([noise, label], image)
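
As an optional sanity check (the variable names below are illustrative, not part of the original code), the generator can be built on its own to confirm that one noise vector and one digit label produce a single 28 x 28 x 1 image:

# Build the generator and run a single forward pass
g = build_generator()
sample_noise = np.random.normal(0, 1, (1, latent_dim))
sample_label = np.array([[7]])   # ask the generator for a "7"
print(g.predict([sample_noise, sample_label]).shape)   # expected: (1, 28, 28, 1)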



Step 5: Defining a utility function to build the Discriminator

def build_discriminator():
  
    model = Sequential()
      
    # Building the input layer 
    model.add(Conv2D(16, (3,3), strides=2, input_shape=image_shape, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
  
    model.add(Conv2D(32, (3,3), strides=2, padding="same"))
    model.add(ZeroPadding2D(padding=((0,1),(0,1))))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
  
    model.add(BatchNormalization(momentum=0.8))
  
    model.add(Conv2D(64, (3,3), strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
  
    model.add(BatchNormalization(momentum=0.8))
  
    model.add(Conv2D(128, (3,3), strides=1, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
  
    model.add(Flatten())
  
    image = Input(shape=image_shape)
  
    # Extract features from images
    features = model(image)
  
    # Building the output layer 
    validity = Dense(1, activation="sigmoid")(features)
    label = Dense(classes, activation="softmax")(features)
  
    return Model(image, [validity, label])
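
Similarly, an optional check (again with illustrative variable names) confirms that the discriminator returns two outputs for every input image: a real/fake probability and a 10-way class distribution.

# Build the discriminator and inspect the shapes of its two outputs
d = build_discriminator()
dummy_image = np.random.normal(0, 1, (1,) + image_shape)
validity, label_probs = d.predict(dummy_image)
print(validity.shape, label_probs.shape)   # expected: (1, 1) and (1, 10)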



Step 6: Defining a utility function to display the generated images

def display_images():
    # Plots a 10 x 10 grid of generated digits; each row shows the digits
    # 0-9 in order. Uses the global `generator` built in Step 8.
    r = 10
    c = 10
    noise = np.random.normal(0, 1, (r * c, latent_dim))
  
    new_labels = np.array([num for _ in range(r) for num in range(c)])
    gen_images = generator.predict([noise, new_labels])
  
    # Rescale images 0 - 1
    gen_images = 0.5 * gen_images + 0.5
  
    fig, axs = plt.subplots(r, c) 
    count = 0
    for i in range(r): 
        for j in range(c): 
            axs[i,j].imshow(gen_images[count,:,:,0], cmap='gray')
            axs[i,j].axis('off')
            count += 1
    plt.show() 
    plt.close()



Step 7: Building and Training the AC-GAN 

def train_acgan(epochs, batch_size=128, sample_interval=50):
  
    # Load the dataset
    (X, y), (_, _) = mnist.load_data()
  
    # Configure inputs
    X = X.astype(np.float32)
    X = (X - 127.5) / 127.5
    X = np.expand_dims(X, axis=3)
    y = y.reshape(-1, 1)
  
    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
  
    for epoch in range(epochs):
  
        # Select a random batch of images
        index = np.random.randint(0, X.shape[0], batch_size)
        images = X[index]
  
        # Sample noise as generator input
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
  
        # The labels of the digits that the generator tries to create an
        # image representation of
        new_labels = np.random.randint(0, 10, (batch_size, 1))
  
        # Generate a batch of new images from the sampled noise and labels
        gen_images = generator.predict([noise, new_labels])
  
        image_labels = y[index]
  
        # Training the discriminator
        disc_loss_real = discriminator.train_on_batch(
          images, [valid, image_labels])
        disc_loss_fake = discriminator.train_on_batch(
          gen_images, [fake, new_labels])
        disc_loss = 0.5 * np.add(disc_loss_real, disc_loss_fake)
  
        # Training the generator
        gen_loss = combined.train_on_batch(
          [noise, new_labels], [valid, new_labels])
  
        # Print the accuracies 
        print ("%d [acc.: %.2f%%, op_acc: %.2f%%]" % (
          epoch, 100 * disc_loss[3], 100 * disc_loss[4]))
  
        # display at every defined epoch interval
        if epoch % sample_interval == 0:
            display_images()
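
Note that train_acgan uses the generator, discriminator and combined models that are only built in Step 8 below, so it can be called only after that step. Before committing to the full run, a shorter smoke test (illustrative values, not from the original article) helps confirm that everything is wired correctly:

# Quick smoke test after Step 8: a few hundred iterations with frequent previews
train_acgan(epochs=200, batch_size=32, sample_interval=100)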



Step 8: Building the Generative Adversarial Network

# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss=losses,
    optimizer=optimizer,
    metrics=['accuracy'])
  
# Build the generator
generator = build_generator()
  
# Defining the input for the generator
# and generating the images
noise = Input(shape=(latent_dim,))
label = Input(shape=(1,))
image = generator([noise, label])
  
# Freeze the discriminator's weights:
# for the combined model we will only train the generator
discriminator.trainable = False
  
# The discriminator takes in the generated image
# as input and determines validity
# and the label of that image
valid, target_label = discriminator(image)
  
# The combined model (both generator and discriminator)
# Training the generator to fool the discriminator
combined = Model([noise, label], [valid, target_label])
combined.compile(loss=losses, optimizer=optimizer)
  
train_acgan(epochs=14000, batch_size=32, sample_interval=2000)
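
Once training finishes, the trained generator can optionally be saved and reused without retraining. A minimal sketch (the file name acgan_generator.h5 is only an example, not part of the original article):

# Save the trained generator and reload it later to create new digits
generator.save('acgan_generator.h5')
  
from keras.models import load_model
saved_generator = load_model('acgan_generator.h5')
  
sample_noise = np.random.normal(0, 1, (1, latent_dim))
sample_label = np.array([[5]])   # ask the generator for a "5"
digit = saved_generator.predict([sample_noise, sample_label])   # shape: (1, 28, 28, 1)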



Output (at every 2000 epoch interval):

[Image grids of generated digits at epochs 0, 2000, 4000, 6000, 8000, 10000, 12000 and 14000, followed by the final result.]

Visually inspecting the progression of the generated images shows that the network learns to produce recognizable digits for each requested class, i.e. it is working at an acceptable level. The quality of the images can be improved further by training the network for longer or by tuning its hyperparameters.
