Julius Hernandez Alvarado/

Sloth or Pastry? Using PyTorch and Deep Learning for Image Classification


Introduction to Image Classification with PyTorch

We'll be using computer vision to answer the question that never gets old on the internet: is it a sloth or a pain au chocolat? This is a binary image classification task.

In the age of deep learning, data scientists and machine learning engineers seldom create and train neural networks from scratch. A big chunk of what goes into performing a machine learning task, however, is collecting, preparing, and loading data to feed into a model. We'll perform a little bit of fine-tuning on the model, but this will not be the focus of the training.

In this session, we'll adapt code from PyTorch's tutorials on loading custom datasets to load a dataset we have collected into PyTorch.

We'll then follow the PyTorch tutorial on transfer learning to perform an image classification task using a mostly pretrained model, which we'll fine-tune.

Specifically, we'll be labeling images with one of two labels: sloth, or pain_au_chocolat.

Package Imports

Like all great Python projects, ours will start with some package imports! We'll use:

  • NumPy for manipulating numerical arrays
  • Matplotlib.pyplot for plotting
  • time, which provides time-related functions
  • os, which provides functions for interacting with the operating system, such as building file paths
  • copy, for copying objects
  • various packages from torch, including:
    • torch
    • torch.nn, which contains the basic building blocks for neural networks
    • torch.optim, a package containing various optimization algorithms for PyTorch
    • lr_scheduler from torch.optim, for adjusting the learning rate based on the number of epochs
    • torch.backends.cudnn as cudnn, which controls how PyTorch uses cuDNN, NVIDIA's GPU-accelerated library for deep neural networks (although GPUs may not be supported in your workspace)
  • torchvision, which provides additional functionality to manipulate and process images, including
    • datasets, which contains built-in datasets and dataset classes such as ImageFolder
    • models, containing models for various tasks, including image processing
    • transforms, which we'll use to transform images in preparation for image processing
# Package imports go here
import numpy as np
import matplotlib.pyplot as plt

import time 
import os
import copy 

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn 

import torchvision
from torchvision import datasets, models, transforms


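For a faster runtime, let's begin our project by setting cudnn.benchmark to True. This lets cuDNN try several convolution algorithms and pick the fastest one for our hardware, which works best when the input size stays the same from batch to batch.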

# Enable cudnn benchmark
cudnn.benchmark = True

Reading and transforming the data

While the tutorial provides extensive information on loading, transforming, rescaling, cropping, and converting images to tensors using torch and torch.utils, we'll be using the torchvision package, which provides some frequently used data loaders and transforms out-of-the-box.

One of the things it assumes is that data is organized in a certain way. Navigate to the data folder to see how the data is structured. Within the directory called "data/sloths_versus_pain_au_chocolat", there are two folders called "train" and "val". Our dataset contains two labels:

  • sloth, and
  • pain_au_chocolat

so our folders are named and organized accordingly. Note that the images inside the sloth and pain_au_chocolat folders don't need to follow any naming scheme, as long as the folders themselves are labelled correctly.

To adapt this tutorial to different data, all you need to do is rename the sloth and pain_au_chocolat folders and upload different images into them.
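As a rough illustration of how folder names become labels, the sketch below lists the class subfolders in sorted order, assigns each one an integer index, and counts the files inside. It uses only the standard library and approximates the idea behind datasets.ImageFolder() rather than reproducing its code. The helper name summarize_image_folder is made up for this example.

```python
import os

def summarize_image_folder(root):
    """Return ({class_name: index}, {class_name: file_count}) for one split folder.

    Hypothetical helper: it mimics the folder-name-to-label idea only.
    """
    # Every subdirectory of root is treated as one class, in sorted order
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    # Count regular files in each class folder (no image-extension filtering here)
    counts = {
        name: sum(1 for entry in os.scandir(os.path.join(root, name)) if entry.is_file())
        for name in classes
    }
    return class_to_idx, counts

if __name__ == "__main__":
    train_dir = "data/sloths_versus_pain_au_chocolat/train"
    if os.path.isdir(train_dir):
        print(summarize_image_folder(train_dir))
```

Any extra subfolder in a split directory would appear in this mapping as another class, so it is worth checking the output before training.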

When running code in notebooks, a hidden folder called .ipynb_checkpoints can sometimes show up in our training and validation folders. Because every subfolder is interpreted as a class, we'll remove these with the lines below.

# Banish pesky .ipynb_checkpoints folders (-f keeps rm quiet if a folder doesn't exist)
!rm -rf data/sloths_versus_pain_au_chocolat/train/.ipynb_checkpoints

!rm -rf data/sloths_versus_pain_au_chocolat/val/.ipynb_checkpoints
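Outside a notebook, or on a system without rm, the same cleanup can be sketched in plain Python. The helper name remove_checkpoint_dirs is hypothetical, and the default split names simply mirror the train and val folders described above.

```python
import os
import shutil

def remove_checkpoint_dirs(data_dir, splits=("train", "val")):
    """Delete .ipynb_checkpoints from each split folder and return the paths removed."""
    removed = []
    for split in splits:
        path = os.path.join(data_dir, split, ".ipynb_checkpoints")
        # Only act when the folder is really there, so repeated calls are harmless
        if os.path.isdir(path):
            shutil.rmtree(path)
            removed.append(path)
    return removed

if __name__ == "__main__":
    print(remove_checkpoint_dirs("data/sloths_versus_pain_au_chocolat"))
```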

We'll begin loading and transforming our data by defining the specific transforms we'd like to use from torchvision.

The specific transforms we'll use on our training set are:

  • RandomResizedCrop(), used to crop a random portion of an image and resize it to a given size, passed as the first argument to the function
  • RandomHorizontalFlip(), used to horizontally flip an image randomly with a given probability (default is 0.5)
  • ToTensor(), used to convert an image or numpy.ndarray to a tensor
  • Normalize(), used to normalize a tensor image with given means and standard deviations, passed as lists as the first and second arguments, respectively (taking tensors as input). If the images are similar to ImageNet images, we can use the mean and standard deviation of the ImageNet dataset. These are:
    • mean = [0.485, 0.456, 0.406]
    • std = [0.229, 0.224, 0.225].
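To make the arithmetic behind Normalize() concrete: each channel value x becomes (x - mean) / std for that channel. The sketch below applies this to a single RGB pixel in plain Python. The function name normalize_pixel and the sample pixels are made up for illustration.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (x - mean) / std to each channel of one RGB pixel scaled to [0, 1]."""
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]

if __name__ == "__main__":
    # A pixel equal to the channel means maps to zero in every channel
    print(normalize_pixel([0.485, 0.456, 0.406]))  # [0.0, 0.0, 0.0]
    # A pure white pixel ends up more than two standard deviations above zero
    print(normalize_pixel([1.0, 1.0, 1.0]))
```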

The specific transforms we'll use on our validation set are:

  • Resize(), used to resize an input to a given size, passed as the first argument
  • CenterCrop(), used to crop a given image at the center, based on dimensions provided in the first argument
  • ToTensor()
  • Normalize()
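# Create data transforms
# Assumption: 224 and 256 are the sizes used in the PyTorch transfer learning tutorial
data_transforms = {
    'train' : transforms.Compose([
         transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    'val' : transforms.Compose([
         transforms.Resize(256), transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}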
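The validation pipeline's Resize() followed by CenterCrop() comes down to simple box arithmetic. The sketch below approximates that arithmetic in plain Python and is not torchvision's code. The sizes 256 and 224 and the 512 x 768 example image are assumed values chosen for illustration, and both function names are hypothetical.

```python
def resize_shorter_side(height, width, target):
    """Scale (height, width) so the shorter side equals target, keeping aspect ratio."""
    if height <= width:
        return target, int(target * width / height)
    return int(target * height / width), target

def center_crop_box(height, width, crop):
    """Return (top, left, bottom, right) of a crop x crop box centered in the image."""
    top = int(round((height - crop) / 2.0))
    left = int(round((width - crop) / 2.0))
    return top, left, top + crop, left + crop

if __name__ == "__main__":
    h, w = resize_shorter_side(512, 768, 256)
    print(h, w)                        # 256 384
    print(center_crop_box(h, w, 224))  # (16, 80, 240, 304)
```

Because validation images always pass through the same deterministic resize and crop, the validation results are repeatable, whereas the random crop and flip in training vary from epoch to epoch.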

Next, we'll:

  • create a data directory path containing our dataset
  • pass our directory to datasets.ImageFolder() to create a dictionary of datasets called image_datasets, where the images are labelled according to the way our folders are currently structured
  • use image_datasets to obtain our training and validation dataset_sizes and class_names
  • pass image_datasets to torch.utils.data.DataLoader(), which enables us to sample from our dataset, using
    • batch_size = 4, which uses 4 images per batch
    • shuffle = True, which will shuffle the data at every epoch
# Provide data directory
data_dir = 'data/sloths_versus_pain_au_chocolat'

# Create image folders for our training and validation data 
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                 for x in ['train', 'val']}

# Obtain dataset sizes from image_datasets
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# Obtain class_names from image_datasets
class_names = image_datasets['train'].classes

# Use image_datasets to sample from the dataset
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size = 4,
                                              shuffle = True)
               for x in ['train', 'val']}
# Change selected device to CUDA, a parallel processing platform, if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
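To make the batch_size and shuffle settings concrete, here is a conceptual sketch of how a dataset's indices can be shuffled and grouped into batches of 4. It is not how DataLoader is implemented. The function name make_batches and the dataset size of 10 are hypothetical.

```python
import random

def make_batches(num_samples, batch_size=4, shuffle=True, seed=None):
    """Return a list of index batches covering every sample exactly once."""
    indices = list(range(num_samples))
    if shuffle:
        # Reorder the indices, as would happen at the start of each epoch
        random.Random(seed).shuffle(indices)
    # Slice into consecutive chunks; the last batch may be smaller
    return [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]

if __name__ == "__main__":
    batches = make_batches(10, batch_size=4, shuffle=True, seed=0)
    print(len(batches))               # 3
    print([len(b) for b in batches])  # [4, 4, 2]
```

With its default drop_last=False, DataLoader likewise keeps a smaller final batch when the dataset size is not a multiple of the batch size.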

Visualizing sloths and pastries with a custom function!

Defining the function

def imshow(inp, title=None):
    """
    This function will make use of Matplotlib.pyplot's imshow() function for tensors.
    It will show the same number of images as the batch we defined.
    """
    # A transpose is required to move channels last: (C, H, W) -> (H, W, C)
    inp = inp.numpy().transpose((1, 2, 0))

    # Using the ImageNet mean and std from Normalize(); customize if yours differ
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Undo the normalization to visualize the correct colors
    inp = std * inp + mean
    # Clip values to the [0, 1] range that Matplotlib expects for float images
    inp = np.clip(inp, 0, 1)
    # Visualize inp
    plt.imshow(inp)
    if title is not None: # Plot title goes here
        plt.title(title)
    plt.pause(0.001)  # Enables the function to pause while the plots are updated
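The two array operations inside imshow() can be checked in isolation with NumPy: transpose((1, 2, 0)) moves the channel axis from first to last, which is the layout Matplotlib expects, and std * inp + mean reverses the Normalize() step. The 2 x 2 image below is made-up random data used only for illustration.

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Made-up image in height x width x channel layout, values in [0, 1)
original = np.random.default_rng(0).random((2, 2, 3))

# What the transform pipeline produces: normalized values, channel-first layout
normalized_chw = ((original - mean) / std).transpose((2, 0, 1))

# What imshow() does: back to channel-last, then undo the normalization
restored = std * normalized_chw.transpose((1, 2, 0)) + mean
restored = np.clip(restored, 0, 1)

print(normalized_chw.shape)             # (3, 2, 2)
print(restored.shape)                   # (2, 2, 3)
print(np.allclose(restored, original))  # True
```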

Calling our imshow() function

# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

# Plot the grid with a title that concatenates all the class labels
imshow(out, title = [class_names[x] for x in classes])
# Get a batch of validation data
inputs, classes = next(iter(dataloaders['val']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

# Plot the grid with a title that concatenates all the class labels
imshow(out, title = [class_names[x] for x in classes])