Hauke Brammer
05. April 2018
12 min

Image classification with CNNs and small augmented datasets

Machine learning requires lots of data. However, can you get some meaningful results from just a hundred examples? And if so, how do you do that?

The easiest way to train your machine learning algorithm on a small dataset is to make the dataset bigger. This might sound contradictory, but in this post I will show you a simple way to augment your small image datasets with the help of Keras.


One of the classic examples in image recognition is the MNIST dataset. It consists of a collection of 70,000 grayscale images with a fixed size of 28×28 pixels. Each image shows a handwritten digit between 0 and 9.

In this post, we will use Zalando’s Fashion-MNIST dataset. This dataset is a direct replacement for the regular MNIST dataset but offers a bigger challenge than its predecessor, for which error rates below one percent are now common. The 70,000 images in the new dataset have the same dimensions and are also divided into ten classes. Given that the dataset comes from Zalando, you may already have guessed that the images depict clothes and shoes instead of handwritten digits.

Zalando introduced this dataset in a 2017 paper to offer an alternative to the overused MNIST dataset.

Since Fashion-MNIST conveniently has the same dimensions as regular MNIST, it was quickly integrated into a bunch of machine learning libraries like TensorFlow or PyTorch.

So let’s get started!

Importing, normalizing, visualizing…

First we let Keras download the dataset for us.

The images consist of grayscale values between 0.0 and 255.0. We normalize them by dividing the whole data arrays by 255.0.

The images are stored in 784 columns but were originally 28 by 28 pixels. We will later reshape them to their original format.
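The normalization and the later reshape can be sketched with plain numpy. This is only an illustration under assumptions: the random array x_raw is a hypothetical stand-in for the downloaded data, and the variable names are mine, not the ones from the notebook.

```python
import numpy as np

# Stand-in for the downloaded data: 5 flattened images with 784 grayscale
# values each (hypothetical random data, not the real Fashion-MNIST images).
rng = np.random.default_rng(0)
x_raw = rng.integers(0, 256, size=(5, 784)).astype("float32")

# Scale the grayscale values from [0, 255] to [0, 1].
x_norm = x_raw / 255.0

# Restore the original 28 x 28 layout of every image.
x_images = x_norm.reshape(-1, 28, 28)

print(x_images.shape, x_images.min(), x_images.max())
```

The -1 in reshape lets numpy work out the number of images on its own, so the same line works for the training and the test data.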

Let’s now take a look at our nice dataset. For easier plotting of the images, we define a plotting function that we will use quite often to visualize intermediate results.

Let’s also define a function that we can use to pick a random subset from the training data.
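Such a helper could look roughly like the following sketch. The function name pick_per_class, its parameters and the toy data are assumptions for illustration; the function in the original notebook may differ.

```python
import numpy as np

def pick_per_class(x, y, n_per_class, seed=None):
    """Return n_per_class randomly chosen examples for every label in y."""
    rng = np.random.default_rng(seed)
    chosen = []
    for label in np.unique(y):
        # Indices of all examples that carry this label.
        candidates = np.flatnonzero(y == label)
        # Draw without replacement so no image is picked twice.
        chosen.append(rng.choice(candidates, size=n_per_class, replace=False))
    idx = np.concatenate(chosen)
    return x[idx], y[idx]

# Hypothetical toy data: 200 flattened "images", 20 for each of the labels 0..9.
x_toy = np.zeros((200, 784), dtype="float32")
y_toy = np.repeat(np.arange(10), 20)

x_small, y_small = pick_per_class(x_toy, y_toy, n_per_class=10, seed=42)
print(x_small.shape, np.bincount(y_small))
```

The result is grouped by label, which makes it easy to plot one category per row.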

Here are a hundred examples from our training data. Each row is one category with ten examples.

Label  Description
0      T-shirt/top
1      Trouser
2      Pullover
3      Dress
4      Coat
5      Sandal
6      Shirt
7      Sneaker
8      Bag
9      Ankle boot

Can you tell every coat apart from a pullover? I certainly can’t. But let’s see if a small convolutional neural net can.



Our model will consist of just two stacks of two convolution layers each. Each layer has a ReLU activation. After each stack we put a max-pooling layer.

On top of these convolution layers we put two fully connected layers. The last layer gets one unit per category, as it has to decide to which category each image belongs. As the loss function we use categorical_crossentropy to train our model.
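To get a feeling for what this loss measures, here is a small numpy sketch of the categorical cross-entropy for one-hot labels. It is my own minimal version for illustration, not the Keras implementation, and the two example predictions are made up.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Turn integer labels into one-hot rows."""
    return np.eye(num_classes)[labels]

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean cross-entropy between one-hot targets and predicted probabilities."""
    y_pred = np.clip(y_pred, eps, 1.0)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

y_true = one_hot(np.array([4]))      # one image with label 4 (Coat)

uniform = np.full((1, 10), 0.1)      # a model that has learned nothing yet
confident = np.full((1, 10), 0.01)   # a model that is fairly sure about label 4
confident[0, 4] = 0.91

loss_uniform = categorical_crossentropy(y_true, uniform)
loss_confident = categorical_crossentropy(y_true, confident)
print(loss_uniform, loss_confident)
```

Only the probability assigned to the true category enters the loss, so the more of it the last layer puts on the right unit, the smaller the loss gets.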

After compiling the model, we can see that it has a total of 126,122 trainable parameters.
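If you wonder where such a number comes from: every layer contributes its weights plus one bias per output. The following sketch shows the arithmetic for convolution and fully connected layers. The layer sizes below are hypothetical examples and are not meant to add up to the 126,122 parameters of the model in this post.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """One weight per kernel entry and input channel, plus one bias, per filter."""
    return (kernel_h * kernel_w * in_channels + 1) * filters

def dense_params(inputs, units):
    """One weight per input, plus one bias, per unit."""
    return (inputs + 1) * units

# Hypothetical layer sizes, chosen only to illustrate the formulas.
first_conv = conv2d_params(3, 3, in_channels=1, filters=32)    # grayscale input
second_conv = conv2d_params(3, 3, in_channels=32, filters=32)
output_layer = dense_params(inputs=128, units=10)              # one unit per category

print(first_conv, second_conv, output_layer)
```

Max-pooling layers have no parameters at all, they only shrink the feature maps.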

Baseline with full data

To see how our tests with smaller datasets perform in comparison with the full original dataset, we first need to establish a baseline.

For that we transform all of our data to a format that TensorFlow can understand: The first dimension indexes the individual training images, and the second and third dimensions are the x- and y-axes of each image. The fourth dimension holds the color channels. Since we only work with grayscale images here, it has just one entry.
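In numpy terms this transformation is a single reshape. The sketch below uses a zero-filled stand-in array with hypothetical names instead of the real training data.

```python
import numpy as np

# Stand-in for 100 flattened, normalized images (hypothetical data).
x_flat = np.zeros((100, 784), dtype="float32")

# (images, height, width, channels) with a single grayscale channel.
x_4d = x_flat.reshape(-1, 28, 28, 1)

print(x_4d.shape)
```

An RGB dataset would have a 3 in the last position instead of the 1.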

We set the number of epochs to 30. On an okayish laptop that will take about 30 minutes to run. If you have a better machine, feel free to increase the number of epochs and see what happens.

Now we train the model on our complete training data and use the whole test data as validation. For nicer visualization of the training progress we add the TQDMNotebookCallback to the callback list. I didn’t embed the progress visualizations in the post. They are included in the original Jupyter notebook (see link at the bottom).

Now we have a baseline against which we can compare our augmented data.

Creating training data with augmentation

Augmentation of image datasets is really easy with the keras.preprocessing.image.ImageDataGenerator class.

With the ImageDataGenerator you can apply random transformations to a given set of images. This way you can effectively increase the number of images you can use for training.

What makes the ImageDataGenerator extra convenient is that we can use it as direct input to the model.fit() function without generating and saving a bunch of images first.

Now I want to take a closer look at the transformations you can apply:


Let’s take just one image first to see what the transformers do to it. We can use all of these transformers via the ImageDataGenerator or on their own if we want to.


random_shift allows you to randomly shift an image by a given fraction of the image size in each direction. Here we specify wrg=0.1 and hrg=0.2. That means that we shift up to 0.2 x image size (0.2 x 28 = 5.6) pixels up or down and up to 0.1 x image size (0.1 x 28 = 2.8) pixels left or right.
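The arithmetic behind the two ranges can be sketched like this. The variable names are my own, and the sketch only draws the random offsets; it does not move any pixels.

```python
import numpy as np

height, width = 28, 28
hrg, wrg = 0.2, 0.1       # height and width shift ranges from the text

max_dy = hrg * height     # up to about 5.6 pixels up or down
max_dx = wrg * width      # up to about 2.8 pixels left or right

# One random offset per direction, drawn uniformly from the allowed range.
rng = np.random.default_rng(1)
dy = rng.uniform(-max_dy, max_dy)
dx = rng.uniform(-max_dx, max_dx)

print(max_dy, max_dx, dy, dx)
```

Because the offsets are fractional, the shifted image has to be interpolated, and the pixels that move in from outside the image need a fill value.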

In all transformer functions you can specify row_axis, col_axis and channel_axis according to the array of images you pass into the function.

You can also specify a fill_mode for pixel values that are not originally in the image. For instance, if we shift an image up by 3 pixels, we need to fill the 3 new rows of pixels with some value.

To illustrate the different values of fill_mode I will use the following example image: 1234

You can set fill_mode to one of the following values:

  • constant: Fill the missing values with a constant value. You can specify the constant value with the option cval. Otherwise 0.0 will be used. Example with cval=5:
    55555555 |1234| 55555555
  • nearest: The nearest non-empty pixel value is used. Example:
    11111111 |1234| 44444444
  • reflect: The image is reflected on the original image border and the values are filled in accordingly. Example:
    12344321 |1234| 43211234
  • wrap: The original image is repeated multiple times for the empty pixels. Example:
    12341234 |1234| 12341234
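You can reproduce the four patterns for the example image 1234 with numpy.pad. Note that this is an analogy and not what Keras calls internally: I am assuming that numpy's edge mode corresponds to nearest and numpy's symmetric mode corresponds to reflect, which matches the examples above.

```python
import numpy as np

row = np.array([1, 2, 3, 4])   # the example image 1234
pad = 8                        # number of missing pixels on each side

def show(padded):
    """Format a padded row as 'left |original| right'."""
    digits = "".join(str(v) for v in padded)
    return f"{digits[:pad]} |{digits[pad:-pad]}| {digits[-pad:]}"

examples = {
    "constant": show(np.pad(row, pad, mode="constant", constant_values=5)),
    "nearest": show(np.pad(row, pad, mode="edge")),
    "reflect": show(np.pad(row, pad, mode="symmetric")),
    "wrap": show(np.pad(row, pad, mode="wrap")),
}

for name, text in examples.items():
    print(f"{name:9s}{text}")
```

For clothes on a black background the choice rarely matters much, but for photos that fill the whole frame it can change the augmented images quite a bit.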


With the random_rotation transformer we can rotate the image randomly by up to x degrees clockwise or counterclockwise. Here we specify a maximum rotation of 20 degrees.


The random_shear function shears an image with a random shearing angle that is calculated from the given intensity.

Note that shearing is different from just rotation since it deforms the given image by multiplying it with the following transformation matrix:

\[
\begin{pmatrix}
1 & -\sin(m) & 0 \\
0 & \cos(m) & 0 \\
0 & 0 & 1
\end{pmatrix}
\]

Where \(m\) is \(x * \pi / 180\) with \(x\) being a random float in \([- intensity, intensity]\).
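To see the difference to a rotation in numbers, the following sketch builds the shear matrix from above next to an ordinary rotation matrix. The intensity of 20 and the function names are assumptions for illustration. A rotation matrix is orthogonal, so it keeps all lengths and angles, while the shear matrix is not.

```python
import numpy as np

def shear_matrix(x_deg):
    """Shear matrix from the text, with m = x * pi / 180."""
    m = x_deg * np.pi / 180
    return np.array([[1.0, -np.sin(m), 0.0],
                     [0.0, np.cos(m), 0.0],
                     [0.0, 0.0, 1.0]])

def rotation_matrix(x_deg):
    """Plain rotation by x degrees, for comparison."""
    t = x_deg * np.pi / 180
    return np.array([[np.cos(t), -np.sin(t), 0.0],
                     [np.sin(t), np.cos(t), 0.0],
                     [0.0, 0.0, 1.0]])

intensity = 20.0                          # hypothetical value
rng = np.random.default_rng(0)
x = rng.uniform(-intensity, intensity)    # random float in [-intensity, intensity]
print(shear_matrix(x))

# Compare both transformations at a fixed angle of 15 degrees.
S = shear_matrix(15.0)
R = rotation_matrix(15.0)
rotation_is_rigid = bool(np.allclose(R.T @ R, np.eye(3)))
shear_is_rigid = bool(np.allclose(S.T @ S, np.eye(3)))
print(rotation_is_rigid, shear_is_rigid)
```

The shear leaves the horizontal axis where it is and tilts only the vertical one, which is why a sheared shirt looks slanted rather than turned.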


random_zoom zooms in and out of an image. It doesn’t use the same zooming factor for horizontal and vertical zoom but two independent random values. We can specify a minimum (here 0.7) and a maximum value (here 1.3) for the zoom. In Keras the factor scales the region that is sampled from the source image, so a value smaller than 1.0 zooms in, making the object in the image bigger, and a value bigger than 1.0 zooms out.


Now we combine every transformation that we just did in one ImageDataGenerator. It is also possible to allow a flip of the image either horizontally or vertically. For now we disallow that option.

When we start the ImageDataGenerator it runs in an endless loop. But since we just want a few examples, we let it run in a for loop and break out of it when we have collected enough of them.

This allows us to create 100 images from just one image.
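The pattern of looping over an endless generator and breaking out by hand can be sketched without Keras. In the sketch below, endless_augmenter is a hypothetical stand-in for the ImageDataGenerator: it only shifts the image by whole pixels with np.roll, which behaves like a shift with the wrap fill mode.

```python
import numpy as np

def endless_augmenter(image, max_shift=2, seed=None):
    """Yield randomly shifted copies of image forever (toy stand-in)."""
    rng = np.random.default_rng(seed)
    while True:
        dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
        yield np.roll(image, shift=(int(dy), int(dx)), axis=(0, 1))

# A hypothetical 28 x 28 image with a bright square in the middle.
image = np.zeros((28, 28), dtype="float32")
image[10:18, 10:18] = 1.0

examples = []
for augmented in endless_augmenter(image, seed=0):
    examples.append(augmented)
    if len(examples) >= 100:   # the generator never stops on its own
        break

print(len(examples), examples[0].shape)
```

Without the break the for loop would never terminate, because the generator has no last element.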

Training with augmented datasets

To test the effectiveness of the augmentation of our dataset we will try to train our model on randomly sampled training sets of different sizes. We will use 1, 10, 100 and 1000 examples per class and train with each reduced dataset for 30 epochs.

For this we first define an image generator like above.

Now we define a function that will train a model with a specified number of samples per category: First, we randomly pick a number of samples from each category from the original training dataset with the function we defined earlier.

Then we feed this sample of training data into the ImageDataGenerator and initialize it. We define a batch size of 30, which means that the generator will produce 30 randomly transformed images on each call.

We create a new model with the same structure as the one we defined earlier for the original training data.

Finally we train the model on data from the generator with the fit_generator() function instead of the “standard” fit(). We choose 2000 steps per epoch so that every epoch sees 30 x 2000 = 60,000 training examples (batch size times steps), the same number as in the original training dataset. But instead of 60,000 totally different images we now have images that are generated from a much, much smaller set of images.
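As a quick sanity check, here is the bookkeeping for the numbers used above. The variable names are mine.

```python
batch_size = 30          # images the generator produces per call
steps_per_epoch = 2000   # generator calls per epoch
epochs = 30

images_per_epoch = batch_size * steps_per_epoch
total_images = images_per_epoch * epochs

print(images_per_epoch, total_images)
```

So over the whole training run the model sees 1.8 million generated images, no matter how few original samples the generator starts from.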

Now let’s test our model with a bunch of examples.

One sample per category

Ten samples per category

Hundred samples per category

Thousand samples per category


Now we have tested with different datasets of increasing sizes. Let’s plot the results for training and validation accuracy:

After training the model with the given number of training samples for 30 epochs we reach the following final accuracies:

Number of samples per category    Accuracy
6000 (original dataset)           0.9093
1000                              0.8896
100                               0.8161
10                                0.6774
1                                 0.3099

While it is still true that more data leads to better results, we can reach about 82% accuracy with less than two percent of the original dataset when we use an image generator to augment our small training datasets.
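The share of the original training data behind each row of the table is easy to compute. The accuracies below are copied from the table; the variable names are mine.

```python
classes = 10
full_size = 6000 * classes   # 60,000 images in the original training set

# Samples per category -> final accuracy, taken from the table above.
results = {1000: 0.8896, 100: 0.8161, 10: 0.6774, 1: 0.3099}

shares = {n: n * classes / full_size for n in results}
for n, accuracy in results.items():
    print(f"{n:5d} per category = {shares[n]:.4%} of the data, accuracy {accuracy}")
```

With 100 samples per category we use 1,000 of the 60,000 images, which is about 1.7 percent of the original dataset.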

Smaller training datasets lead to stronger overfitting problems, as we can see in the high training accuracy but low validation accuracy. Data augmentation is one way to mitigate this problem. We could add other methods such as dropout and regularization to further improve our results.

There are also other possible solutions to working with small datasets. You could, for example, retrain an available and already trained network to fit your specific use case (this is something I will demonstrate in an upcoming post).

In this post I showed you how you can use the Keras ImageDataGenerator to augment small image datasets really easily and efficiently. The full Jupyter notebook with all the code that was produced in this post is available on GitHub. Did you use the ImageDataGenerator in one of your projects? Did you even come up with your own augmentation method for image data? Please let me know.
