Prerequisites: Generative Adversarial Network
This article demonstrates how to build a Generative Adversarial Network (GAN) using the Keras library. The dataset used is the CIFAR10 image dataset, which ships with Keras. You can read about the dataset here.
Step 1: Importing the required libraries
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.layers import Input , Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam,SGD
|
Step 2: Loading the data
# Loading the CIFAR10 data (only the training split is used; the test
# split is discarded)
(X, y), (_, _) = keras.datasets.cifar10.load_data()

# Selecting the images of a single class.
# The class index was chosen arbitrarily; any label from 0 to 9 can be
# used (label 8 is the "ship" class in CIFAR10).
# y has one label per row, so it is flattened to build a 1-D boolean mask.
X = X[y.flatten() == 8]
|
Step 3: Defining parameters to be used in later processes
# Defining the input shape: 32 x 32 pixel images with 3 colour channels
image_shape = (32, 32, 3)
# Length of the random noise vector fed to the generator
latent_dimensions = 100
|
Step 4: Defining a utility function to build the Generator
def build_generator():
    """Build the generator network.

    The generator maps a noise vector of length ``latent_dimensions`` to an
    image. The noise is projected to an 8 x 8 x 128 feature map, upsampled
    twice (8 -> 16 -> 32 with the default 2x ``UpSampling2D``), and reduced
    to 3 channels with a ``tanh`` activation, so pixel values lie in
    [-1, 1].

    Returns:
        A Keras ``Model`` taking a ``(latent_dimensions,)`` noise input and
        returning the generated image tensor.
    """
    model = Sequential()

    # Building the input layer: project the noise and reshape it into a
    # small spatial feature map
    model.add(Dense(128 * 8 * 8, activation="relu",
                    input_dim=latent_dimensions))
    model.add(Reshape((8, 8, 128)))

    # First upsampling stage
    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=3, padding="same"))
    model.add(BatchNormalization(momentum=0.78))
    model.add(Activation("relu"))

    # Second upsampling stage
    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=3, padding="same"))
    model.add(BatchNormalization(momentum=0.78))
    model.add(Activation("relu"))

    # Collapse to 3 colour channels; tanh keeps the output in [-1, 1]
    model.add(Conv2D(3, kernel_size=3, padding="same"))
    model.add(Activation("tanh"))

    # Generating the output image
    noise = Input(shape=(latent_dimensions,))
    image = model(noise)

    return Model(noise, image)
|
Step 5: Defining a utility function to build the Discriminator
def build_discriminator():
    """Build the discriminator network.

    The discriminator is a stack of convolutional layers that takes an
    image of shape ``image_shape`` and outputs a single sigmoid value: the
    estimated probability that the image is real rather than generated.

    Returns:
        A Keras ``Model`` taking an ``image_shape`` input and returning a
        validity score of shape ``(1,)`` per image.
    """
    # Building the convolutional layers
    # to classify whether an image is real or fake
    model = Sequential()

    model.add(Conv2D(32, kernel_size=3, strides=2,
                     input_shape=image_shape, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    # Pad one row at the bottom and one column on the right
    model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
    model.add(BatchNormalization(momentum=0.82))
    model.add(LeakyReLU(alpha=0.25))
    model.add(Dropout(0.25))

    model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization(momentum=0.82))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))

    model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.25))
    model.add(Dropout(0.25))

    # Building the output layer
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))

    image = Input(shape=image_shape)
    validity = model(image)

    return Model(image, validity)
|
Step 6: Defining a utility function to display the generated images
def display_images():
    """Show a 4 x 4 grid of images sampled from the current generator.

    Reads the module-level ``generator`` and ``latent_dimensions``. Fresh
    noise is drawn on every call, so each call shows different samples.
    """
    r, c = 4, 4
    noise = np.random.normal(0, 1, (r * c, latent_dimensions))
    generated_images = generator.predict(noise)

    # Scaling the generated images from the generator's tanh range
    # [-1, 1] to [0, 1], which is what imshow expects for float data
    generated_images = 0.5 * generated_images + 0.5

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, matching the image order
    for ax, generated_image in zip(axs.flat, generated_images):
        ax.imshow(generated_image)
        ax.axis('off')
    plt.show()
    plt.close(fig)
|
Step 7: Building the Generative Adversarial Network
# Building and compiling the discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(0.0002, 0.5),
                      metrics=['accuracy'])

# Making the discriminator untrainable inside the combined model, so that
# training the combined model only updates the generator's weights.
# This is set after the discriminator's own compile() call above.
discriminator.trainable = False

# Building the generator
generator = build_generator()

# Defining the input for the generator
# and generating the images
z = Input(shape=(latent_dimensions,))
image = generator(z)

# Checking the validity of the generated image
valid = discriminator(image)

# Defining the combined model of the Generator and the Discriminator:
# noise -> generator -> discriminator -> validity score
combined_network = Model(z, valid)
combined_network.compile(loss='binary_crossentropy',
                         optimizer=Adam(0.0002, 0.5))
|
Step 8: Training the network
num_epochs = 15000
batch_size = 32
display_interval = 2500
# One (discriminator loss, generator loss) pair is recorded per iteration
losses = []

# Normalizing the input from [0, 255] to [-1, 1], the same range as the
# generator's tanh output
X = (X / 127.5) - 1.

# Defining the adversarial ground truths
valid = np.ones((batch_size, 1))
# Adding some noise to the labels
# NOTE(review): the noise is added, so the "real" targets end up in
# [1, 1.05) rather than below 1 -- confirm this is intended.
valid += 0.05 * np.random.random(valid.shape)
fake = np.zeros((batch_size, 1))
fake += 0.05 * np.random.random(fake.shape)

# Each "epoch" here is a single training step on one batch
for epoch in range(num_epochs):

    # Training the Discriminator

    # Sampling a random batch of real images (with replacement)
    index = np.random.randint(0, X.shape[0], batch_size)
    images = X[index]

    # Sampling noise and generating a batch of new images
    noise = np.random.normal(0, 1, (batch_size, latent_dimensions))
    generated_images = generator.predict(noise)

    # Training the discriminator to detect more accurately
    # whether a generated image is real or fake
    discm_loss_real = discriminator.train_on_batch(images, valid)
    discm_loss_fake = discriminator.train_on_batch(generated_images, fake)
    # The discriminator is compiled with an accuracy metric, so each
    # result is [loss, accuracy]; average the real and fake results
    discm_loss = 0.5 * np.add(discm_loss_real, discm_loss_fake)

    # Training the Generator

    # Training the generator to generate images
    # which pass the authenticity test
    genr_loss = combined_network.train_on_batch(noise, valid)

    # Tracking the progress
    losses.append((discm_loss[0], genr_loss))
    if epoch % display_interval == 0:
        display_images()
|
Epoch 0:
Epoch 2500:
Epoch 5000:
Epoch 7500:
Epoch 10000:
Epoch 12500:
Note that the quality of the generated images improves as training progresses.
Step 9: Evaluating the performance
The performance of the network will be evaluated by comparing the images generated on the last epoch to the original images visually.
a) Plotting the original images
# Plotting some of the original images
s = X[:40]
# X was rescaled to [-1, 1] before training; map it back to [0, 1]
# so that imshow renders the colours correctly
s = 0.5 * s + 0.5

f, ax = plt.subplots(5, 8, figsize=(16, 10))
# ax.flat walks the 5 x 8 grid row by row
for axis, image in zip(ax.flat, s):
    axis.imshow(image)
    axis.axis('off')

plt.show()
b) Plotting the images generated on the last epoch
# Plotting some images produced by the trained generator
noise = np.random.normal(size=(40, latent_dimensions))
generated_images = generator.predict(noise)
# Map the generator's tanh output from [-1, 1] to [0, 1] for imshow
generated_images = 0.5 * generated_images + 0.5

f, ax = plt.subplots(5, 8, figsize=(16, 10))
# ax.flat walks the 5 x 8 grid row by row
for axis, image in zip(ax.flat, generated_images):
    axis.imshow(image)
    axis.axis('off')

plt.show()
On visually comparing the two sets of images, it can be concluded that the network is working at an acceptable level. The quality of images can be improved by training the network for more time or by tuning the parameters of the network.