Keras data generators and how to use them in TensorFlow

If you have ever tried to train a neural network, you have probably run into a situation where the dataset you want to load does not fit in your machine's memory. This is a common problem, especially in computer vision, where datasets often contain many thousands of images.

In this tutorial, we look at how to use data generators in Keras to load and process images batch by batch, so the whole dataset never has to sit in memory at once, and save the day.

Data Generators In Keras

So what are Data Generators or Image Data Generators?

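Essentially, ImageDataGenerator is a Keras class, found in the keras.preprocessing.image module, that is very useful for image processing.

It generates batches of tensor image data with real-time data augmentation: random transformations are applied on the fly as each batch is produced, instead of augmented copies being stored on disk. Supported augmentations include horizontal and vertical flipping, rotation, shifting, shearing and zooming.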
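To make "batches with real-time augmentation" concrete, here is a minimal NumPy sketch of the idea. It is not how Keras implements ImageDataGenerator internally, and the helper augmented_batches and the load_image callable are hypothetical names chosen for illustration. It shows the three ingredients: only one batch of images is loaded at a time, each batch is rescaled to [0, 1], and each image is flipped horizontally with probability 0.5.

```python
import numpy as np

def augmented_batches(paths, labels, batch_size, load_image, rng):
    """Hypothetical helper: yield (x, y) batches, loading only batch_size images at a time.

    load_image is an assumed callable that reads one file and returns a uint8
    array of shape (height, width, channels) with values in [0, 255].
    """
    order = rng.permutation(len(paths))  # shuffle once per pass over the data
    for start in range(0, len(paths), batch_size):
        idx = order[start:start + batch_size]
        # Only the images of the current batch are held in memory
        x = np.stack([load_image(paths[i]) for i in idx]).astype("float32") / 255.0
        # Flip each image left-right with probability 0.5 (width is axis 2 of the batch)
        flip = rng.random(len(idx)) < 0.5
        x[flip] = x[flip][:, :, ::-1, :]
        yield x, labels[idx]

rng = np.random.default_rng(seed=0)
paths = [f"img{i:03d}.jpg" for i in range(10)]  # placeholder names, never opened
labels = np.array([0, 1] * 5)

def fake_load(path):
    # Stand-in for real image decoding: a random 64x64 RGB image
    return rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)

for x, y in augmented_batches(paths, labels, batch_size=4, load_image=fake_load, rng=rng):
    print(x.shape, y.shape)
```

Because augmented_batches is a Python generator, nothing is loaded until the training loop asks for the next batch, which is what keeps memory usage bounded.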

If you need a more comprehensive explanation of image augmentation, I advise you to have a look at this repository.

Now, to demonstrate how ImageDataGenerator is used in practice, I will work with the well-known Cats and Dogs dataset.

The dataset is organized in a specific layout that lets ImageDataGenerator load it without any extra work: one directory per split, and inside each split one subdirectory per class.

data/
    train/
        dogs/
            dog001.jpg
            dog002.jpg 
            ...
        cats/
            cat001.jpg
            cat002.jpg
            ...
    validation/
        dogs/
            dog001.jpg
            dog002.jpg
            ...
        cats/
            cat001.jpg
            cat002.jpg
            ...
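flow_from_directory treats every subdirectory of the given directory as one class and assigns label indices in alphanumeric order of the folder names. The plain-Python sketch below mimics that mapping on an empty copy of the layout above; infer_class_indices is a hypothetical helper, not part of Keras. The real iterator exposes the same mapping through its class_indices attribute.

```python
import os
import tempfile

def infer_class_indices(directory):
    """Hypothetical helper: map each class subdirectory to an integer label, in alphanumeric order."""
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))  # ignore stray files
    )
    return {name: index for index, name in enumerate(classes)}

# Recreate an empty skeleton of the layout above in a temporary folder
with tempfile.TemporaryDirectory() as root:
    for split in ("train", "validation"):
        for cls in ("dogs", "cats"):
            os.makedirs(os.path.join(root, "data", split, cls))
    class_indices = infer_class_indices(os.path.join(root, "data", "train"))

print(class_indices)  # {'cats': 0, 'dogs': 1}
```

For this dataset, then, a label of 0 means cat and 1 means dog, regardless of the order in which the folders were created.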

First things first, we import the ImageDataGenerator class so we can use it later. With TensorFlow installed, it is available through tf.keras:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

flow_from_directory

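To load the dataset we will use the .flow_from_directory method, which takes several parameters. The most relevant are:

  •  directory: the path to the data on the machine, containing one subdirectory per class.
  •  target_size: the (height, width) every image is resized to; the default is (256, 256).
  •  batch_size: the number of images in each batch; the default is 32.
  •  class_mode: the format of the labels; the default is 'categorical' (one-hot labels), while 'binary' produces 0/1 labels.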

See the Python code given below:

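# Training generator: rescale pixels to [0, 1] and apply random augmentation
train_datagen = ImageDataGenerator(rescale=1./255,
                                   shear_range=0.2,       # random shear intensity
                                   zoom_range=0.2,        # random zoom factor in [0.8, 1.2]
                                   horizontal_flip=True)  # randomly flip images left-right

# Test generator: rescale only, no augmentation
test_datagen = ImageDataGenerator(rescale=1./255)

# Each subdirectory of data/train (cats/, dogs/) becomes one class
training_set = train_datagen.flow_from_directory('data/train',
                                                 target_size=(64, 64),  # resize every image to 64x64
                                                 batch_size=32,
                                                 class_mode='binary')   # 0/1 labels

test_set = test_datagen.flow_from_directory('data/validation',
                                            target_size=(64, 64),
                                            batch_size=32,
                                            class_mode='binary')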

In this code block, the training generator applies several of the augmentation techniques mentioned earlier (shearing, zooming and horizontal flipping), and both generators rescale pixel values from [0, 255] to [0, 1].

Some points to note from this code:

  • Augmentation is applied only to the training set, not to the test set. It is worth thinking about why: augmentation gives the model more varied examples to learn from, while the test set should measure performance on images as they actually arrive, unmodified.
  • target_size is set to (64, 64) instead of the default (256, 256), so every image is resized to 64x64 pixels as it is loaded and all batches have the same shape.
  • class_mode is changed from the default ‘categorical’ to ‘binary’ because the task at hand is binary classification (cat vs. dog), not a multiclass problem.
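The difference between the two class_mode settings is easiest to see on the labels themselves. Below is a NumPy sketch assuming the cats → 0, dogs → 1 mapping that alphanumeric folder order gives; the batch contents are made up for illustration. ‘binary’ yields a 1-D array with one 0/1 label per image, while ‘categorical’ yields a 2-D array of one-hot rows.

```python
import numpy as np

class_indices = {"cats": 0, "dogs": 1}            # mapping inferred from folder names
batch_classes = ["dogs", "cats", "cats", "dogs"]  # hypothetical classes of one batch

# class_mode='binary': 1-D array, one 0/1 label per image
y_binary = np.array([class_indices[c] for c in batch_classes], dtype="float32")

# class_mode='categorical': 2-D one-hot array, one column per class
y_categorical = np.eye(len(class_indices), dtype="float32")[y_binary.astype(int)]

print(y_binary.shape, y_categorical.shape)  # (4,) (4, 2)
```

Binary labels pair naturally with a single sigmoid output unit and binary cross-entropy, whereas one-hot labels pair with a softmax output layer and categorical cross-entropy.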

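Training a model with ImageDataGenerator is simple. Once the model architecture is defined and compiled, call the model's .fit method and pass the generators where you would normally pass arrays. Older Keras code used .fit_generator for this, which is deprecated in TensorFlow 2.x because .fit accepts generators directly. The argument names changed as well: Keras 1's samples_per_epoch, nb_epoch and nb_val_samples became steps_per_epoch, epochs and validation_steps, and the step arguments count batches rather than images.

# classifier is the compiled Keras model (its architecture is not shown here)
classifier.fit(training_set,                       # the training generator
               steps_per_epoch=len(training_set),  # batches per epoch, e.g. ceil(8000 / 32) = 250
               epochs=10,                          # number of epochs
               validation_data=test_set,           # the test generator
               validation_steps=len(test_set))     # e.g. ceil(2000 / 32) = 63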
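Because modern Keras counts steps (batches) rather than samples, the sample counts used for this dataset, 8000 training and 2000 test images, have to be converted at batch size 32. The sketch below does that arithmetic with a hypothetical helper steps_for; as far as I know it matches what len() returns for a directory iterator, since the last, partial batch still counts as a step.

```python
import math

def steps_for(num_samples: int, batch_size: int) -> int:
    """Hypothetical helper: batches needed to see every sample once (last batch may be partial)."""
    return math.ceil(num_samples / batch_size)

steps_per_epoch = steps_for(8000, 32)   # 250 full batches
validation_steps = steps_for(2000, 32)  # 63: 62 full batches plus one batch of 16 images
print(steps_per_epoch, validation_steps)
```

Rounding down instead (2000 // 32 = 62) would silently skip the last 16 validation images in every epoch.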

And that’s it! Once you know the pattern, it is simple to reuse it for other image datasets.

If you want the full code of this classification task, including predictions on unseen data, here is the repository for your reference.
