Python | Classify Handwritten Digits with TensorFlow

Classifying handwritten digits is a basic problem in machine learning and can be solved in many ways. Here we will implement two of them using TensorFlow.

Using a Linear Classifier Algorithm with tf.contrib.learn

A linear classifier assigns a class to a handwritten digit based on the value of a linear combination of its features. The features (here, the pixel values of the image) are presented to the machine in a vector called a feature vector.
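
The idea can be sketched in a few lines of NumPy. This is only an illustration of "choose the class with the largest linear score", not what tf.contrib.learn does internally; the weights, bias and feature values below are made-up toy numbers (MNIST would use 784 features and 10 classes).

```python
import numpy as np

def linear_scores(x, weights, bias):
    """Return one score per class: a linear combination of the features plus a bias."""
    return weights @ x + bias

def classify(x, weights, bias):
    """Pick the class whose linear score is the largest."""
    return int(np.argmax(linear_scores(x, weights, bias)))

# Hypothetical toy problem: 3 features, 2 classes.
toy_weights = np.array([[1.0, 0.0, -1.0],
                        [-1.0, 0.5, 2.0]])
toy_bias = np.array([0.0, 0.1])
toy_features = np.array([0.2, 0.4, 0.9])

toy_scores = linear_scores(toy_features, toy_weights, toy_bias)
toy_class = classify(toy_features, toy_weights, toy_bias)
print(toy_scores, toy_class)
```

Training a linear classifier amounts to finding the weights and bias that make the correct class score highest on the training examples.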

Modules required :

NumPy:



$ pip install numpy 

Matplotlib:

$ pip install matplotlib 

TensorFlow (a 1.x release is needed for this section, because tf.contrib was removed in TensorFlow 2.0):

$ pip install tensorflow 

Steps to follow

Step 1 : Importing all dependencies

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
  
learn = tf.contrib.learn
  
tf.logging.set_verbosity(tf.logging.ERROR)


Step 2 : Importing the MNIST dataset

mnist = learn.datasets.load_dataset('mnist')
data = mnist.train.images
labels = np.asarray(mnist.train.labels, dtype=np.int32)
test_data = mnist.test.images
test_labels = np.asarray(mnist.test.labels, dtype=np.int32)


After this step the MNIST dataset is downloaded and extracted.
Output :

Extracting MNIST-data/train-images-idx3-ubyte.gz
Extracting MNIST-data/train-labels-idx1-ubyte.gz
Extracting MNIST-data/t10k-images-idx3-ubyte.gz
Extracting MNIST-data/t10k-labels-idx1-ubyte.gz

Step 3 : Making the dataset (keeping the first 10,000 training examples)

max_examples = 10000
data = data[:max_examples]
labels = labels[:max_examples]


Step 4 : Displaying the dataset using Matplotlib



def display(i):
    # each MNIST image is stored flat and reshaped to 28 by 28 pixels
    img = test_data[i]
    plt.title('label : {}'.format(test_labels[i]))
    plt.imshow(img.reshape((28, 28)))
    plt.show()

# show the first test image
display(0)


To display an image, call this function with the image's index, for example display(0).
Output : a plot of the first test image, titled with its label.
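
The reshape inside display can be illustrated on its own. This is a small sketch that assumes each image arrives as a flat vector of 784 values, which is why the function calls reshape((28, 28)) before plotting; flat_image below is a made-up stand-in, not real MNIST data.

```python
import numpy as np

# Hypothetical stand-in for one flat image: 784 values between 0 and 1.
flat_image = np.linspace(0.0, 1.0, 28 * 28, dtype=np.float32)

# Row-major reshape: the first 28 values become the top row of the image,
# the next 28 values the second row, and so on.
image_2d = flat_image.reshape((28, 28))

print(image_2d.shape)
```

Because the reshape only changes how the same values are indexed, flattening image_2d again gives back the original vector.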

Step 5 : Fitting the data using a linear classifier

feature_columns = learn.infer_real_valued_columns_from_input(data)
classifier = learn.LinearClassifier(n_classes=10,
                                    feature_columns=feature_columns)
classifier.fit(data, labels, batch_size=100, steps=1000)
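
To get a feel for the batch_size and steps values, the short calculation below estimates how many passes over the training data they correspond to. It assumes that every step consumes exactly one batch; the helper name epochs_covered is made up for this illustration.

```python
def epochs_covered(steps, batch_size, num_examples):
    """Number of passes over the data implied by a step count and a batch size."""
    return steps * batch_size / num_examples

# Values used in this article: 1000 steps of 100 examples on 10000 training examples.
passes = epochs_covered(steps=1000, batch_size=100, num_examples=10000)
print(passes)
```

Under that assumption, the classifier sees each of the 10,000 training examples about ten times.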


Step 6 : Evaluate accuracy

# evaluate once on the test set and print the accuracy
print(classifier.evaluate(test_data, test_labels)["accuracy"])


Output :

0.9137
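
Accuracy here means the fraction of test images whose predicted digit equals the true label. A minimal sketch of that calculation, using made-up predictions and labels rather than the classifier's real output, could look like this:

```python
import numpy as np

def accuracy(predicted_labels, true_labels):
    """Fraction of positions where the prediction equals the true label."""
    predicted_labels = np.asarray(predicted_labels)
    true_labels = np.asarray(true_labels)
    return float(np.mean(predicted_labels == true_labels))

# Hypothetical toy data: ten predictions, one of them wrong.
toy_predictions = [7, 2, 1, 0, 4, 1, 4, 9, 6, 9]
toy_truth = [7, 2, 1, 0, 4, 1, 4, 9, 5, 9]

toy_accuracy = accuracy(toy_predictions, toy_truth)
print(toy_accuracy)
```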

Step 7 : Predicting data

prediction = classifier.predict(np.array([test_data[0]], 
                                         dtype=float), 
                                         as_iterable=False)
print("prediction : {}, label : {}".format(prediction, 
      test_labels[0]) )


Output :

prediction : [7], label : 7

Full code for classifying handwritten digits

# importing libraries
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
  
learn = tf.contrib.learn
tf.logging.set_verbosity(tf.logging.ERROR)
  
# importing the MNIST dataset
# mnist contains separate train and test sets
mnist = learn.datasets.load_dataset('mnist')
data = mnist.train.images
labels = np.asarray(mnist.train.labels, dtype = np.int32)
test_data = mnist.test.images
test_labels = np.asarray(mnist.test.labels, dtype = np.int32)
  
max_examples = 10000
data = data[:max_examples]
labels = labels[:max_examples]
  
# displaying dataset using Matplotlib
def display(i):
    img = test_data[i]
    plt.title('label : {}'.format(test_labels[i]))
    plt.imshow(img.reshape((28, 28)))
      
# each MNIST image is 28 by 28 pixels
# fitting linear classifier
feature_columns = learn.infer_real_valued_columns_from_input(data)
classifier = learn.LinearClassifier(n_classes = 10,
                                    feature_columns = feature_columns)
classifier.fit(data, labels, batch_size = 100, steps = 1000)
  
# Evaluate accuracy
print(classifier.evaluate(test_data, test_labels)["accuracy"])
  
prediction = classifier.predict(np.array([test_data[0]], 
                                         dtype=float), 
                                         as_iterable=False)
print("prediction : {}, label : {}".format(prediction, 
      test_labels[0]) )
  
# show the image only when the prediction matches the label
if prediction[0] == test_labels[0]:
    display(0)
    plt.show()


Using Deep learning with tf.keras

Deep learning is a subfield of machine learning and artificial intelligence that uses neural networks with several layers, also known as deep neural networks. These networks learn useful representations directly from raw data such as pixel values; they can be trained on labeled data, as in this article, and in some settings on unlabeled data as well. Here, we will implement a neural network in TensorFlow to classify handwritten digits.
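
To make "several layers" concrete, here is a rough NumPy sketch of a forward pass through a small fully connected network: each hidden layer applies a linear map followed by ReLU, and the last layer applies softmax to produce class probabilities. The layer sizes, the random weights and the helper names are assumptions for illustration only; no training is performed.

```python
import numpy as np

def relu(z):
    """Replace negative values with zero."""
    return np.maximum(z, 0.0)

def softmax(z):
    """Turn scores into probabilities that sum to 1."""
    shifted = z - np.max(z)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / np.sum(exp)

def forward(x, layers):
    """layers is a list of (weights, bias); ReLU on hidden layers, softmax on the last."""
    activation = x
    for weights, bias in layers[:-1]:
        activation = relu(weights @ activation + bias)
    weights, bias = layers[-1]
    return softmax(weights @ activation + bias)

# Hypothetical untrained network: 784 inputs, two hidden layers, 10 outputs.
rng = np.random.default_rng(0)
sizes = [784, 128, 128, 10]
toy_layers = [(rng.normal(0.0, 0.05, (n_out, n_in)), np.zeros(n_out))
              for n_in, n_out in zip(sizes[:-1], sizes[1:])]

toy_input = rng.random(784)
toy_probs = forward(toy_input, toy_layers)
print(toy_probs.shape, toy_probs.sum())
```

Training adjusts the weights and biases so that the probability assigned to the correct digit becomes large.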

Modules required :



NumPy:

$ pip install numpy 

Matplotlib:

$ pip install matplotlib 

Tensorflow:

$ pip install tensorflow 

Steps to follow

Step 1 : Importing all dependencies

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


Step 2 : Import data and normalize it

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# rescale the pixel values along axis 1 of each array
x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)
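
As far as I understand it, tf.keras.utils.normalize divides the data by its L2 norm along the chosen axis (L2 being the default order). The NumPy sketch below mimics that behaviour on a tiny made-up array; l2_normalize is a hypothetical helper written for this illustration, not part of TensorFlow.

```python
import numpy as np

def l2_normalize(x, axis=1):
    """Divide x by its L2 norm along `axis`, leaving all-zero slices unchanged."""
    x = np.asarray(x, dtype=float)
    norms = np.linalg.norm(x, ord=2, axis=axis, keepdims=True)
    norms[norms == 0] = 1.0  # avoid division by zero for blank slices
    return x / norms

# Hypothetical batch with one 2 by 2 "image".
toy_batch = np.array([[[3.0, 0.0],
                       [4.0, 0.0]]])

toy_normalized = l2_normalize(toy_batch, axis=1)
print(toy_normalized)
```

A common alternative for image data is to divide the raw pixel values by 255 so that they fall between 0 and 1.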


Step 3 : View the data

def draw(n):
    plt.imshow(n,cmap=plt.cm.binary)
    plt.show() 
      
draw(x_train[0])


Step 4 : Make a neural network and train it

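# Keras offers the Sequential and functional APIs for building models;
# Sequential fits here because the network is a plain stack of layers

model = tf.keras.models.Sequential()

# flatten each 28 by 28 image into a vector of 784 values
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))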
  
model.add(tf.keras.layers.Dense(128,activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(128,activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10,activation=tf.nn.softmax))
  
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy']
              )
model.fit(x_train,y_train,epochs=3)
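
The loss named in compile, sparse_categorical_crossentropy, compares predicted class probabilities with integer labels. A simplified NumPy version is sketched below, assuming the loss is the mean of the negative log probability of the true class; it skips details such as the clipping of very small probabilities, and the toy numbers are made up.

```python
import numpy as np

def sparse_categorical_crossentropy(probabilities, labels):
    """Mean of -log(p[true class]) over a batch; labels are integer class ids."""
    probabilities = np.asarray(probabilities, dtype=float)
    labels = np.asarray(labels, dtype=int)
    # pick, for every row, the probability assigned to its true class
    picked = probabilities[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(picked)))

# Hypothetical batch of two examples with three classes.
toy_probabilities = [[0.7, 0.2, 0.1],
                     [0.1, 0.8, 0.1]]
toy_labels = [0, 1]

toy_loss = sparse_categorical_crossentropy(toy_probabilities, toy_labels)
print(toy_loss)
```

The "sparse" variant is used here because y_train holds the digits themselves (0 to 9) rather than one-hot vectors.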


Step 5 : Check model accuracy and loss



val_loss,val_acc = model.evaluate(x_test,y_test)
print("loss-> ",val_loss,"\nacc-> ",val_acc)


Step 6 : Prediction using the model

predictions = model.predict(x_test)
print('label -> ', y_test[2])
print('prediction -> ', np.argmax(predictions[2]))
  
draw(x_test[2])
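
Each row returned by predict holds one probability per digit, so np.argmax turns a row into a predicted label. The same idea extends to a whole batch, which also makes it easy to list the misclassified examples. The arrays below are made-up toy values with three classes, not real model output.

```python
import numpy as np

# Hypothetical probabilities for three examples and three classes.
toy_predictions = np.array([[0.05, 0.90, 0.05],
                            [0.60, 0.30, 0.10],
                            [0.20, 0.20, 0.60]])
toy_labels = np.array([1, 1, 2])

# most probable class for every row
predicted_classes = np.argmax(toy_predictions, axis=1)

# positions where the prediction disagrees with the label
wrong_indices = np.flatnonzero(predicted_classes != toy_labels)

print(predicted_classes, wrong_indices)
```

With real data, the images at those positions could be passed to draw to inspect what the model got wrong.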


Saving and testing the model

Saving the model

# saving the model
# the .h5 extension stores it as a single HDF5 file
  
model.save('epic_num_reader.h5')


Loading the saved model

new_model = tf.keras.models.load_model('epic_num_reader.h5')


Prediction using the new model

predictions = new_model.predict(x_test)

print('label -> ', y_test[2])
print('prediction -> ', np.argmax(predictions[2]))
  
draw(x_test[2])


A Computer Science and Engineering undergraduate student at IERT, Allahabad with an interest in Programming, Data Science/AI and web development

Improved By : itsvinayak