CIFAR-10 Image Classification in TensorFlow


In this article, we discuss how to classify images using TensorFlow. Image classification is the task of assigning each image to one of a set of category classes. The CIFAR-10 dataset, as its name suggests, contains 10 different categories of images. It has a total of 60000 images across 10 classes, namely Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, and Truck. All the images are of size 32×32. There are 50000 training images and 10000 test images.

To build an image classifier, we use TensorFlow's Keras API. Training is much faster with GPU support; if you do not have a GPU, you can use Google Colab notebooks instead.

Stepwise Implementation:

import tensorflow as tf
# Display the version
print(tf.__version__)
# other imports
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Dropout
from tensorflow.keras.layers import GlobalMaxPooling2D, MaxPooling2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.models import Model



The above code should print the version of TensorFlow you are using, e.g. 2.4.1.

# Load in the data
cifar10 = tf.keras.datasets.cifar10
# Distribute it to train and test set
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)


The output of the above code displays the shapes of all four partitions:

(50000, 32, 32, 3) (50000, 1) (10000, 32, 32, 3) (10000, 1)

Here we can see that we have 50000 training images and 10000 test images, as stated above. All the images are 32 by 32 pixels and have 3 color channels, i.e. they are color images. We can also see that a single label is assigned to each image.

Next, we scale the pixel values to the range 0 to 1 by dividing them by 255. We also flatten the label values (in simple words, rearrange them into a single row) using the flatten() function.

# Reduce pixel values
x_train, x_test = x_train / 255.0, x_test / 255.0
# flatten the label values
y_train, y_test = y_train.flatten(), y_test.flatten()
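
To see what these two steps do in isolation, here is a minimal sketch on a tiny hand-made array rather than the real dataset. The names toy_images and toy_labels are hypothetical stand-ins for x_train and y_train, and the pixel values are invented for illustration.

```python
import numpy as np

# Hypothetical stand-in for x_train: 2 images, 2x2 pixels, 3 channels,
# stored as uint8 just like the CIFAR-10 arrays
toy_images = np.array([[[[0, 128, 255], [64, 64, 64]],
                        [[255, 0, 0], [0, 0, 255]]]] * 2, dtype=np.uint8)
# Hypothetical stand-in for y_train: one label per image, stored as a column
toy_labels = np.array([[3], [8]], dtype=np.uint8)

scaled = toy_images / 255.0      # uint8 values 0..255 become floats 0.0..1.0
flat = toy_labels.flatten()      # shape (2, 1) becomes shape (2,)

print(toy_images.shape, scaled.dtype, scaled.min(), scaled.max())
print(toy_labels.shape, "->", flat.shape)
```

Flattening matters later because a loss that takes integer class indices expects one number per image, not a column of one-element rows.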

# visualize data by plotting images
fig, ax = plt.subplots(5, 5)
k = 0
for i in range(5):
    for j in range(5):
        ax[i][j].imshow(x_train[k], aspect='auto')
        k += 1


Though the images are low-resolution, there are enough pixels for us to tell which object each image contains.

# calculate the total number of classes
# for the output layer
K = len(set(y_train))
print("number of classes:", K)

# Build the model using the functional API
# input layer
i = Input(shape=x_train[0].shape)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(i)
x = BatchNormalization()(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)
x = Flatten()(x)
x = Dropout(0.2)(x)
# Hidden layer
x = Dense(1024, activation='relu')(x)
x = Dropout(0.2)(x)
# output layer
x = Dense(K, activation='softmax')(x)
model = Model(i, x)
# model description
model.summary()
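
As a rough cross-check of the architecture above, the feature-map sizes and parameter counts can be worked out by hand. The sketch below is plain Python arithmetic, not Keras code. It assumes the standard counts of one bias per filter or unit and four values per channel for BatchNormalization.

```python
# Trace the feature-map size and parameter count through the network above.
height = width = 32      # CIFAR-10 image size
channels = 3             # RGB input
total_params = 0

for filters in (32, 64, 128):            # three blocks of Conv-BN-Conv-BN-Pool
    for _ in range(2):
        # 3x3 kernel over `channels` inputs, plus one bias per filter
        total_params += (3 * 3 * channels + 1) * filters
        # BatchNormalization: gamma, beta, moving mean, moving variance
        total_params += 4 * filters
        channels = filters               # padding='same' keeps height and width
    height, width = height // 2, width // 2   # MaxPooling2D((2, 2)) halves both

flat_size = height * width * channels
total_params += (flat_size + 1) * 1024   # Dense(1024): weights plus biases
total_params += (1024 + 1) * 10          # Dense(K) with K = 10 classes

print("feature map before Flatten:", (height, width, channels))
print("Flatten output size:", flat_size)
print("total parameters:", total_params)
```

The three pooling layers reduce 32×32 to 4×4, so Flatten produces 4 × 4 × 128 = 2048 values. The totals can be compared with the figures Keras reports for the model.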


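Our model is now ready, so it’s time to compile it using the model.compile() function. For the parameters, we use sparse_categorical_crossentropy as the loss, because the labels are integer class indices, and accuracy as the metric. The optimizer is not stated here; adam is assumed below as a common default.

# Compile
# optimizer='adam' is an assumption; the loss and metric follow from the
# integer labels and the 'accuracy' history keys plotted later
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Fit
r = model.fit(
  x_train, y_train, validation_data=(x_test, y_test), epochs=50)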
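
With integer labels and a softmax output layer, a sparse categorical cross-entropy loss is the natural choice. The sketch below is a simplified NumPy version of what such a loss computes, not the Keras implementation. The probabilities are invented for illustration.

```python
import numpy as np

def sparse_categorical_crossentropy(probs, labels):
    """Mean of -log(probability the model assigned to the true class)."""
    probs = np.asarray(probs, dtype=np.float64)
    labels = np.asarray(labels)
    # Pick, for each row, the probability at that row's true class index
    true_class_probs = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(true_class_probs)))

# Hypothetical softmax outputs for 2 images over 3 classes (rows sum to 1)
probs = [[0.7, 0.2, 0.1],
         [0.1, 0.1, 0.8]]
labels = [0, 2]          # integer class indices, as produced by flatten()

loss = sparse_categorical_crossentropy(probs, labels)
print(round(loss, 4))
```

The loss shrinks toward 0 as the probability on the correct class approaches 1, which is what training tries to achieve.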


The model will start training and report the loss and accuracy on the training and validation data after each epoch.

# Fit with data augmentation
# Note: if you run this AFTER calling
# the previous model.fit(),
# it will CONTINUE training where it left off
batch_size = 32
data_generator = tf.keras.preprocessing.image.ImageDataGenerator(
  width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True)
train_generator = data_generator.flow(x_train, y_train, batch_size)
steps_per_epoch = x_train.shape[0] // batch_size
r = model.fit(train_generator, validation_data=(x_test, y_test),
              steps_per_epoch=steps_per_epoch, epochs=50)
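
To make the augmentation settings more concrete, the following sketch mimics a horizontal flip and a one-pixel width shift on a tiny array with NumPy. It is only an illustration of the idea, not how ImageDataGenerator is implemented. The helper shift_right is hypothetical, and the edge-repeating fill is assumed to match the generator's default fill mode.

```python
import numpy as np

# Hypothetical 4x4 single-channel "image" with distinct values
# so that every move is visible
image = np.arange(16).reshape(4, 4)

flipped = image[:, ::-1]          # horizontal flip: mirror left to right

def shift_right(img, pixels):
    """Shift columns right by `pixels`, repeating the edge column in the gap."""
    if pixels == 0:
        return img.copy()
    shifted = np.empty_like(img)
    shifted[:, pixels:] = img[:, :-pixels]
    shifted[:, :pixels] = img[:, :1]      # nearest-edge fill
    return shifted

shifted = shift_right(image, 1)

# With 50000 training images and batch_size = 32, as in the article
steps_per_epoch = 50000 // 32
print(flipped[0], shifted[0], steps_per_epoch)
```

On real 32-pixel-wide images, width_shift_range=0.1 allows shifts of up to about 3 pixels in either direction, so each epoch sees slightly different versions of the same training images.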


The model will train for 50 epochs. Even on a GPU, this takes at least 10 to 15 minutes.

# Plot accuracy per epoch
plt.plot(r.history['accuracy'], label='acc', color='red')
plt.plot(r.history['val_accuracy'], label='val_acc', color='green')
plt.legend()


Let’s make a prediction on an image using the model.predict() function. Before sending an image to the model, we need to scale its pixel values to between 0 and 1 and change its shape to (1, 32, 32, 3), because the model expects a batch of images in this form. To keep things easy, let us take an image from the dataset itself. Its pixel values are already scaled, but we still have to reshape it to (1, 32, 32, 3) using the reshape() function. Since the image comes from the dataset, we can compare the predicted label with the original label.

# label mapping
labels = '''airplane automobile bird cat deer
dog frog horse ship truck'''.split()
# select the image from our test dataset
image_number = 0
# display the image
plt.imshow(x_test[image_number])
# load the image in an array
n = np.array(x_test[image_number])
# reshape it
p = n.reshape(1, 32, 32, 3)
# pass in the network for prediction and
# save the predicted label
predicted_label = labels[model.predict(p).argmax()]
# load the original label
original_label = labels[y_test[image_number]]
# display the result
print("Original label is {} and predicted label is {}".format(
    original_label, predicted_label))
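
The reshape and argmax steps can be tried without a trained model. In this minimal sketch, single_image and fake_prediction are hypothetical stand-ins for a test image and for the output of model.predict(), with made-up probabilities.

```python
import numpy as np

labels = ['airplane', 'automobile', 'bird', 'cat', 'deer',
          'dog', 'frog', 'horse', 'ship', 'truck']

# Hypothetical stand-in for one test image, shape (32, 32, 3)
single_image = np.zeros((32, 32, 3), dtype=np.float32)
batch = single_image.reshape(1, 32, 32, 3)      # add the batch dimension

# Hypothetical model output for that batch: one row of 10 class probabilities
fake_prediction = np.array([[0.02, 0.01, 0.05, 0.60, 0.04,
                             0.20, 0.03, 0.02, 0.02, 0.01]])

# argmax returns the index of the largest probability, which maps to a label
predicted_index = int(fake_prediction.argmax())
print(batch.shape, predicted_index, labels[predicted_index])
```

Because the batch holds a single image, argmax over the whole output is enough; for a batch of several images, one would take the argmax along axis 1 instead.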


The output shows that the original label is cat and the predicted label is also cat.

Let’s also check an image that our model misclassified. For example, setting image_number = 5722 in the code above gave a predicted label different from the original one in our run.

Finally, let’s save our model as an h5 file using the model.save() function. If you are using Google Colab, you can download your model from the Files section.

# save the model
model.save('geeksforgeeks.h5')

In this way, one can classify images using TensorFlow.

