
Handwritten digit recognition in MNIST

Handwritten digit recognition on the MNIST data is the very first problem for anyone starting out with CNNs, Keras, or TensorFlow. It is a well-defined problem with a standardized dataset which, though not complex, can be used to run deep learning models as well as other machine learning models (logistic regression, XGBoost, or random forest) to predict the digits.


In this example, I'll go through two implementations of MNIST classification: one using a simple neural network and another using a convolutional neural network.

Problem Description:

The MNIST dataset is available for download from many locations, including this one. It also comes preloaded with Keras and TensorFlow.

Complete Code
The complete code is available at this location on GitHub, with comments to make it easier to follow. Here I explain only some of the key parts.

Model Parameters

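# Parameters for the model
batch_size = 128  # assumed value; the original setting is not shown in this post
num_classes = 10  # output classes - 10 digits in this case
epochs = 12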

Batch Size: The number of samples used for each gradient update. Training uses mini-batch gradient descent, where the input data is fed to the network in batches. Read about it here

Epochs: The number of times the complete dataset is passed through the network. The data is passed in mini-batches (of batch_size samples each) until all of it has gone through the network once; this is 1 epoch.
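Keras does the batching inside model.fit (and shuffles the training data each epoch by default), but the relationship between batch_size, epochs and weight updates can be sketched in plain NumPy. The generator iterate_minibatches and the batch size of 128 are hypothetical, chosen only for illustration. With 60,000 training samples, one epoch then consists of ceil(60000 / 128) = 469 updates, the last one on a partial batch of 96 samples.

```python
import math
import numpy as np

def iterate_minibatches(x, y, batch_size, rng):
    """Yield (x_batch, y_batch) pairs that cover every sample exactly once (one epoch)."""
    indices = rng.permutation(len(x))  # shuffle the sample order for this epoch
    for start in range(0, len(x), batch_size):
        batch_idx = indices[start:start + batch_size]
        yield x[batch_idx], y[batch_idx]

# Stand-in arrays with the size of the MNIST training set; real code would use x_train, y_train
n_samples = 60000
batch_size = 128  # hypothetical value for illustration
x = np.zeros((n_samples, 28, 28), dtype=np.uint8)
y = np.zeros(n_samples, dtype=np.uint8)

rng = np.random.default_rng(0)
sizes = [len(xb) for xb, _ in iterate_minibatches(x, y, batch_size, rng)]
print(len(sizes), math.ceil(n_samples / batch_size))  # 469 469
print(sizes[-1])  # 96, since 60000 - 468 * 128 = 96
```

Each item yielded by the generator corresponds to one gradient update; running the loop 12 times would be 12 epochs.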

Loading Data

The data can be loaded into Keras with a single call:

from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

This loads the data already split into 60,000 training and 10,000 test samples.

x_train.shape = (60000, 28, 28): 60,000 examples, each 28 rows by 28 columns.

Reshaping Data

From the Keras documentation for the Conv2D layer:

When using this layer as the first layer in a model, provide the keyword argument input_shape (tuple of integers, does not include the sample axis), e.g. input_shape=(128, 128, 3) for 128x128 RGB pictures in data_format="channels_last".
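The phrase "does not include the sample axis" means input_shape describes a single example, so it is the data array's shape with the first entry dropped. A minimal NumPy illustration using small stand-in arrays (the names x_train_4d and x_flat are hypothetical) for the two layouts used in this post:

```python
import numpy as np

# Small stand-in for x_train after the channels_last reshape used for the CNN
x_train_4d = np.zeros((100, 28, 28, 1), dtype=np.float32)
# input_shape describes one example, so the leading sample axis is dropped
print(x_train_4d.shape[1:])  # (28, 28, 1)

# Small stand-in for x_train after flattening for the simple (dense) network
x_flat = np.zeros((100, 784), dtype=np.float32)
print(x_flat.shape[1:])  # (784,)
```

The resulting tuple is what would be passed as input_shape to the first layer of the respective model.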

Simple Neural Network

The data is fed into a simple neural network with 28*28 = 784 input units. There are no convolutional layers in this case.

The data is therefore reshaped from (60000, 28, 28) to (60000, 784) using the following Python commands:

x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
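The layers of the simple network are not shown in this post, so as a framework-free sketch of what a stack of dense layers computes on the flattened 784-unit input: the single 512-unit ReLU hidden layer and the random weights below are assumptions for illustration, not the author's architecture, and training is omitted. The first assertion also shows how the flattening works: each 28-pixel image row is laid end to end.

```python
import numpy as np

def softmax(z):
    # Subtract the row maximum for numerical stability before exponentiating
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def dense_forward(x, w1, b1, w2, b2):
    """784 inputs -> ReLU hidden layer -> 10-way softmax over the digit classes."""
    hidden = np.maximum(0.0, x @ w1 + b1)
    return softmax(hidden @ w2 + b2)

rng = np.random.default_rng(0)
images = rng.random((32, 28, 28))     # stand-in batch of 32 greyscale images
x = images.reshape(len(images), 784)  # same flattening as reshape(60000, 784)
# Row 1 of the first image ends up in columns 28..55 of the flattened vector
assert np.array_equal(x[0, 28:56], images[0, 1])

hidden_units = 512  # hypothetical hidden-layer width
w1 = rng.normal(0.0, 0.01, (784, hidden_units))
b1 = np.zeros(hidden_units)
w2 = rng.normal(0.0, 0.01, (hidden_units, 10))
b2 = np.zeros(10)

probs = dense_forward(x, w1, b1, w2, b2)
print(probs.shape)  # (32, 10): one probability per digit for each image
```

Because the image is flattened, this network sees each pixel as an independent input and has no built-in notion of which pixels are neighbours; that spatial structure is what the convolutional version exploits.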

Convolutional Neural Network

CNNs need their input in a specific format. mnist.load_data() supplies the MNIST digits with shape (nb_samples, 28, 28), i.e. 2 dimensions per example, representing a 28x28 greyscale image.

The Convolution2D (Conv2D) layers in Keras, however, are designed to work with 3 dimensions per example, so they have 4-dimensional inputs and outputs. This covers colour images, (nb_samples, nb_channels, rows, cols) with channels_first or (nb_samples, rows, cols, nb_channels) with channels_last, but more importantly it covers the deeper layers of the network, where each example has become a set of feature maps, i.e. (nb_samples, nb_features, rows, cols).

Greyscale MNIST digits would therefore need either a different CNN layer design (or a parameter to the layer constructor to accept a different shape), or a standard CNN layer with the examples explicitly expressed as 1-channel images. The Keras team chose the latter approach, which is why the reshape is needed.
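To see why the channel axis matters, here is a deliberately naive NumPy version of the operation a convolutional layer performs (channels_last layout, "valid" padding, no bias or activation). It is an illustrative sketch, not Keras's implementation; the function name conv2d_valid and the filter counts are hypothetical. Because it always sums over an input-channel axis, the same code handles a 1-channel digit and a stack of feature maps from an earlier layer.

```python
import numpy as np

def conv2d_valid(x, kernels):
    """Naive 'valid' 2-D convolution (cross-correlation, as in CNN layers), channels_last.

    x:       (nb_samples, rows, cols, in_channels)
    kernels: (k_rows, k_cols, in_channels, nb_filters)
    returns: (nb_samples, rows - k_rows + 1, cols - k_cols + 1, nb_filters)
    """
    n, rows, cols, _ = x.shape
    kr, kc, _, nf = kernels.shape
    out = np.zeros((n, rows - kr + 1, cols - kc + 1, nf))
    for i in range(rows - kr + 1):
        for j in range(cols - kc + 1):
            patch = x[:, i:i + kr, j:j + kc, :]  # (n, kr, kc, in_channels)
            # Sum over the window and every input channel, separately for each filter
            out[:, i, j, :] = np.tensordot(patch, kernels, axes=([1, 2, 3], [0, 1, 2]))
    return out

rng = np.random.default_rng(0)
digits = rng.random((4, 28, 28, 1))  # 4 stand-in greyscale digits with an explicit channel axis
kernels = rng.random((3, 3, 1, 32))  # 32 hypothetical 3x3 filters over 1 input channel
feature_maps = conv2d_valid(digits, kernels)
print(feature_maps.shape)  # (4, 26, 26, 32)

# A deeper layer: the 32 feature maps now play the role of 32 input channels
deeper = conv2d_valid(feature_maps, rng.random((3, 3, 32, 64)))
print(deeper.shape)  # (4, 24, 24, 64)
```

Passing the raw (4, 28, 28) array would fail at the shape unpacking, which is the same reason Keras needs the 1-channel reshape.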

from keras import backend as K

img_rows, img_cols = 28, 28  # MNIST image dimensions

# Check the format: channels first or last
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)
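The reshape only inserts a length-1 channel axis; it does not move or change any pixel values. A small NumPy check with a hypothetical helper add_channel_axis, which mirrors the branch above but takes data_format as an argument instead of reading K.image_data_format():

```python
import numpy as np

def add_channel_axis(x, data_format):
    """Turn (nb_samples, rows, cols) greyscale images into 4-D input for a conv layer."""
    if data_format == 'channels_first':
        return x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])
    return x.reshape(x.shape[0], x.shape[1], x.shape[2], 1)

images = np.arange(2 * 28 * 28).reshape(2, 28, 28)  # stand-in for x_train
first = add_channel_axis(images, 'channels_first')
last = add_channel_axis(images, 'channels_last')
print(first.shape, last.shape)  # (2, 1, 28, 28) (2, 28, 28, 1)

# Only the shape changes; the pixel values and their order are untouched
assert np.array_equal(first[:, 0], images)
assert np.array_equal(last[..., 0], images)
```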

Output units to one-hot vectors

The labels obtained from the raw data (y_train and y_test) are vectors of integer digits, e.g. 60,000 labels for y_train. They need to be converted into a 60000x10 matrix in which each row is the one-hot encoding of one label: a 1 in the column for that digit and 0 everywhere else.

import keras

# Convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
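For integer labels 0 to num_classes - 1, the same 0/1 layout can be produced in NumPy by indexing rows of an identity matrix. This sketch (the helper one_hot is hypothetical, not a Keras function) shows what to_categorical produces and how argmax recovers the original labels:

```python
import numpy as np

def one_hot(labels, num_classes):
    """One-hot encode integer labels 0..num_classes-1 (same 0/1 layout as to_categorical)."""
    # Row k of the identity matrix is the one-hot vector for class k
    return np.eye(num_classes, dtype=np.float32)[labels]

y = np.array([5, 0, 4])  # example labels
encoded = one_hot(y, 10)
print(encoded.shape)           # (3, 10)
print(encoded.argmax(axis=1))  # [5 0 4]
```

The argmax step is also how a prediction from the softmax output layer is turned back into a digit.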

Plot History

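Keras records the model's accuracy and loss (training and validation) after every epoch. Keep them by assigning the History object returned by model.fit to a variable:

history = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, verbose=1,
                    validation_data=(x_test, y_test))  # validation set assumed; it produces the val_* keys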

history.history holds the results for each epoch. To see what is stored in it, run

history.history.keys()

which will output

Out[66]: dict_keys(['val_loss', 'val_acc', 'loss', 'acc'])
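Because history.history is a plain Python dict that maps each of these keys to a list with one value per epoch, it can be queried without Keras. A minimal sketch with a hypothetical helper best_epoch (not a Keras function) that finds the epoch with the lowest validation loss or the highest validation accuracy:

```python
import numpy as np

def best_epoch(history_dict, monitor='val_loss', mode='min'):
    """Return (1-based epoch, value) of the best entry for `monitor` in a History dict."""
    values = np.asarray(history_dict[monitor], dtype=float)
    idx = int(np.argmin(values) if mode == 'min' else np.argmax(values))
    return idx + 1, float(values[idx])

# Usage after training (requires the fitted `history` object):
# epoch, val_loss = best_epoch(history.history)
# epoch, val_acc = best_epoch(history.history, 'val_acc', mode='max')
```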

To plot the numbers, the following code can be used

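import matplotlib.pyplot as plt
# One figure per metric; newer Keras versions name the accuracy keys 'accuracy' and 'val_accuracy'
for metric, title in [('acc', 'Accuracy (train and validation)'),
                      ('loss', 'Loss (train and validation)')]:
    plt.figure()
    plt.plot(history.history[metric], label='train')
    plt.plot(history.history['val_' + metric], label='validation')
    plt.title(title)
    plt.legend()
plt.show()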

The rest of the code is self-explanatory; please let me know if anything else needs clarification.

Code Source:

Written on February 22, 2018