Ransika Silva


Data Augmentation Techniques for Improving the Robustness of Image Classifiers

Introduction

While training image classification models, we commonly encounter issues such as insufficient training data, overfitting, and poor generalization to unseen images. One way to overcome some of these challenges is data augmentation, a powerful technique that can improve the robustness of our models. In this article, we discuss some common data augmentation techniques and how to implement them in Python using TensorFlow.

What is Data Augmentation?

Data augmentation creates altered copies of every input image in order to increase the size and diversity of the dataset. By applying different transformations to the images, we simulate the kinds of variation the model might encounter in the real world, allowing it to learn more robust and generalizable features [1].

Some common data augmentation techniques, a few of which are sketched in code after this list, include:

  • Flipping
  • Rotation
  • Scaling
  • Cropping
  • Translation
  • Adding noise
  • Adjusting brightness or contrast
  • Perspective transformations
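
To make a few of these concrete, here is a minimal sketch of what some individual transformations look like as tf.image operations; the placeholder image and the crop fraction, brightness delta, and noise level are arbitrary illustrative values.

import tensorflow as tf

image = tf.zeros([224, 224, 3])  # placeholder float image, for illustration only

flipped = tf.image.flip_left_right(image)                        # horizontal flip
cropped = tf.image.central_crop(image, central_fraction=0.8)     # keep the centre 80%
brighter = tf.image.adjust_brightness(image, delta=0.2)          # fixed brightness shift
noisy = image + tf.random.normal(tf.shape(image), stddev=0.05)   # additive Gaussian noise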

The idea is that a robust classifier should be unaffected by these kinds of transformations: it should correctly recognize a photograph of a dog regardless of whether the photograph is mirrored, slightly rotated, or modestly zoomed in.

Implementing data augmentation with TensorFlow

TensorFlow provides a range of image transformation functions in the tf.image module that can be used to build a data augmentation pipeline [2]. These transformations can easily be applied to the input data via the tf.data API.

Here is an example of an augmentation pipeline that chains several transformations:

import tensorflow as tf

# tf.image has no random rotation op, so a Keras RandomRotation layer is used here;
# a factor of 0.03 corresponds to roughly +/-0.2 radians.
rotate = tf.keras.layers.RandomRotation(0.03)

def augment(image, label):
    image = tf.image.resize(image, [224, 224])                       # resize to 224x224
    image = tf.image.random_flip_left_right(image)                   # random horizontal flip
    image = tf.image.random_brightness(image, max_delta=0.5)         # random brightness shift
    image = tf.image.random_contrast(image, lower=0.2, upper=1.8)    # random contrast change
    image = rotate(image, training=True)                             # small random rotation
    return image, label

train_ds = train_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

This pipeline applies a sequence of transformations to every image:

  • Resize the image to 224x224
  • Randomly flip the image horizontally
  • Randomly adjust the brightness by up to 0.5
  • Randomly adjust the contrast by a factor between 0.2 and 1.8
  • Randomly rotate the image by up to roughly 0.2 radians (about 11.5 degrees)

We can then feed this augmented dataset to the model during training:

model.fit(train_ds, epochs=10)
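
Note that model.fit expects batched inputs. If train_ds still yields individual (image, label) pairs after the map above, a minimal sketch of the usual shuffle/batch/prefetch step looks like the following; the buffer and batch sizes are arbitrary illustrative values, not recommendations.

train_ds = (train_ds
            .shuffle(buffer_size=1000)      # reshuffle examples each epoch
            .batch(32)                      # group examples into batches of 32
            .prefetch(tf.data.AUTOTUNE))    # overlap preprocessing with training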

Recommendations for Success with Data Augmentation

  • Start with simple, domain-independent augmentations like flipping and rotation before moving on to stronger or domain-specific ones [3].
  • Watch out for augmentations that can change the class of the image itself, such as a flip or rotation that turns the digit '9' into a '6'.
  • Tune the augmentation parameters to find a good balance: too little augmentation yields negligible gains, while too much can prevent the model from learning meaningful patterns.
  • Preprocessing can also be applied to the validation and test sets for a complete evaluation, but it should use deterministic transformations such as centre cropping rather than random cropping [1] (see the sketch after this list).
  • Data augmentation is no silver bullet; it works best when combined with proper regularization, a well-designed model architecture, and thorough training.
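
To illustrate the point about evaluation-time preprocessing, here is a minimal sketch of a deterministic preprocessing function for a validation set, assuming the same 224x224 input size as before; the preprocess_eval function, the 0.875 crop fraction, and the val_ds dataset are illustrative names and values, not anything prescribed by TensorFlow.

def preprocess_eval(image, label):
    # Deterministic preprocessing only: no randomness at evaluation time.
    image = tf.image.central_crop(image, central_fraction=0.875)   # fixed centre crop
    image = tf.image.resize(image, [224, 224])                     # resize to the model's input size
    return image, label

val_ds = val_ds.map(preprocess_eval, num_parallel_calls=tf.data.AUTOTUNE)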

Visualizing Augmented Images

Visualizing a few augmented images is the best way to make sure the transformations look the way you intend. With TensorFlow's eager execution, you can apply the augmentation to an image and plot the result [4].

Here's a quick example of plotting nine augmented images drawn from the training dataset:

import matplotlib.pyplot as plt

# Assumes train_ds yields individual (image, label) pairs after the augment map.
plt.figure(figsize=(8, 8))
for i, (augmented_image, _) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image)
    plt.axis("off")
plt.show()

Conclusion

Data augmentation is an invaluable part of every machine learning practitioner's toolkit and can improve the performance of image classifiers, especially when confronted with small or imbalanced datasets. By artificially increasing the variability of the training set, it encourages the model to learn stronger, more transferable features.

When paired with hyperparameter tuning, regularization techniques, and transfer learning strategies, augmentation can significantly improve a model's ability to handle the complex, varied visual conditions of the real world [5]. With so many augmentation techniques available, consider their suitability for your specific domain and monitor their impact on model performance closely. Good luck with the augmentation!

References

[1] Shorten, Connor, and Taghi M. Khoshgoftaar. "A survey on Image Data Augmentation for Deep Learning." Journal of Big Data 6.1 (2019): 1-48.
[2] "tf.image: Image Preprocessing - TensorFlow Core v2.11.0." TensorFlow, https://www.tensorflow.org/api_docs/python/tf/image.
[3] Aleju, Marcus. "Data Augmentation for Deep Learning." Medium, 19 July 2020, https://mxbi.medium.com/data-augmentation-for-deep-learning-4fe21d1a4eb9.
[4] "Eager Execution - TensorFlow Core v2.11.0." TensorFlow, https://www.tensorflow.org/guide/eager.
[5] Wang, Jason, and Luis Perez. "The effectiveness of data augmentation in image classification using deep learning." arXiv preprint arXiv:1712.04621 (2017).
