
Image Datasets, Dataloaders, and Transforms in Pytorch

Last Updated : 08 Jun, 2023

PyTorch is becoming an increasingly popular deep learning framework due to its ease of use, support for multiple hardware platforms, and efficient processing. Image datasets, dataloaders, and transforms are essential components for achieving good results with deep learning models in PyTorch.

In this article, we will discuss image datasets, dataloaders, and transforms in Python using the PyTorch library. Image datasets store collections of images that can be used to train, validate, or test deep learning models. These images can be collected from a variety of sources, such as websites, cameras and other devices, and user-generated content.

Dataloaders load the samples from a dataset and deliver them to the model in batches. Transforms are operations that alter aspects of an image such as its color, size, shape, or brightness. Together, these components form the data pipeline for deep learning models that perform tasks such as object recognition, image classification, and image segmentation.

Popular datasets such as ImageNet, CIFAR-10, and MNIST are available through torchvision.datasets and can be used directly as image datasets for Dataloaders. Common image transforms such as random rotation, random cropping, random horizontal or vertical flipping, normalization, and color jittering can be used to prepare model-ready data and to augment the training set. Dataloaders then load batches of data from the dataset efficiently during model training.

Image Datasets, Dataloaders, and Transforms

We will be implementing them on a sample dataset which can be downloaded from this link. You can download this dataset and follow along with this article to understand the concept better.

Import the necessary libraries

We will first import the libraries we will be using in this article.

Python3




import os
import random
import numpy as np
from PIL import Image
  
import matplotlib.pyplot as plt
  
import torch
import torchvision


Image Dataset

An image dataset can be created by defining a class that inherits from the torch.utils.data.Dataset class. Such a map-style dataset must implement two methods:

  • __len__(): returns the number of samples present in the dataset.
  • __getitem__(): returns the sample at the ith index from the dataset.
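
To illustrate the protocol on its own, here is a minimal sketch of a map-style dataset over in-memory arrays. The class name InMemoryImageDataset and the random data are hypothetical; the point is that implementing __len__() and __getitem__() is all a DataLoader needs to index and batch the samples.

```python
import numpy as np
import torch

# Hypothetical map-style dataset over in-memory arrays, standing in for
# images loaded from disk; it implements the two required methods
class InMemoryImageDataset(torch.utils.data.Dataset):
    def __init__(self, images):
        self.images = images

    # Number of samples, used by len() and by the DataLoader's sampler
    def __len__(self):
        return len(self.images)

    # The sample at position `index`
    def __getitem__(self, index):
        return self.images[index]


# Ten random 32x32 RGB images stored as uint8 arrays in (H, W, C) layout
rng = np.random.default_rng(0)
images = [rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8) for _ in range(10)]

ds = InMemoryImageDataset(images)
print(len(ds))      # 10
print(ds[3].shape)  # (32, 32, 3)

# Because the class implements __len__ and __getitem__, a DataLoader can batch it
batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
print(batch.shape)  # torch.Size([4, 32, 32, 3])
```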

We can define such a dataset class for a folder of images in PyTorch as follows:

Python3




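# Creating a custom dataset class
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        # Sorting makes the index -> file mapping deterministic
        # (os.listdir returns the files in arbitrary order)
        self.images = sorted(os.listdir(data_dir))
        self.transform = transform
  
    # Defining the length of the dataset
    def __len__(self):
        return len(self.images)
  
    # Defining the method to get an item from the dataset
    def __getitem__(self, index):
        image_path = os.path.join(self.data_dir, self.images[index])
        # Loading the image as RGB into an (H, W, 3) uint8 NumPy array;
        # the context manager closes the file afterwards
        with Image.open(image_path) as img:
            image = np.array(img.convert('RGB'))
  
        # Applying the transform
        if self.transform:
            image = self.transform(image)
          
        return image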


Now let us use this class on our sample dataset.

Python3




# Replace the path with the path to your dataset
data_path = './maps/train'
  
# Creating a dataset object with the path to the dataset
dataset = ImageDataset(data_path)
  
# Getting the length of the dataset
dataset_length = len(dataset)
  
# Printing the length of the dataset
print('Number of training examples:',dataset_length)
  
# Generating a random index within the dataset length
random_index = random.randint(0, dataset_length - 1)
  
# Plotting the randomly selected image
plt.imshow(dataset[random_index])
plt.show()


Output:

Number of training examples: 1096

[Figure: the randomly selected training image]

Custom Transforms

A custom transform can be created by defining a class with a __call__() method, so that an instance can be called like a function on each sample. Such transforms can be used for preprocessing and data augmentation. We can define a custom transform that preprocesses the input image by splitting it into two equal parts, a left half and a right half, as follows:

Python3




# Defining a custom transform class
class CustomTransform(object):
    def __init__(self, split_percent=0.5):
        self.split_percent = split_percent
      
    # Defining the transform method
    def __call__(self, image):
        # Splitting the (H, W, C) image along its width into a left and a right part
        split = int(image.shape[1] * self.split_percent)
        image1 = image[:, :split, :]
        image2 = image[:, split:, :]
          
        # Returning the two parts of the image
        return image1, image2
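
The two parts are still NumPy arrays in (H, W, C) layout with values from 0 to 255. That suits plt.imshow, but most PyTorch vision models expect float tensors shaped (C, H, W). A possible extra transform for that conversion is sketched below; the class name PairToTensor is hypothetical and the scaling to [0, 1] is an assumption. In a training pipeline it could be placed at the end of a Compose list, after CustomTransform().

```python
import numpy as np
import torch

# Hypothetical transform: converts every (H, W, C) uint8 array in a tuple
# (e.g. the (image, target) pair from CustomTransform) into a
# (C, H, W) float32 tensor scaled to [0, 1]
class PairToTensor(object):
    def __call__(self, pair):
        return tuple(self._to_tensor(part) for part in pair)

    @staticmethod
    def _to_tensor(part):
        # ascontiguousarray copies views with negative strides (e.g. from
        # np.flip), which torch.from_numpy cannot wrap
        array = np.ascontiguousarray(part.transpose(2, 0, 1))
        return torch.from_numpy(array).float() / 255.0


# Usage on a synthetic 4x8 RGB "image" whose right half is white
full = np.zeros((4, 8, 3), dtype=np.uint8)
full[:, 4:, :] = 255
image, target = PairToTensor()((full[:, :4, :], full[:, 4:, :]))
print(image.shape, image.max().item())    # torch.Size([3, 4, 4]) 0.0
print(target.shape, target.min().item())  # torch.Size([3, 4, 4]) 1.0
```

Keeping this conversion as a separate transform leaves the NumPy output available for plotting, as in the examples of this article.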


To use multiple transform objects in PyTorch, you can make use of the torchvision.transforms.Compose class. This class allows you to create an object that represents a composition of different transform objects while maintaining the order in which you want them to be applied.

Python3




# Defining a composition of transforms
transform = torchvision.transforms.Compose([
    # Replace with the transform object(s)
    CustomTransform(),
])


Now let us use this transform with the custom dataset class.

Python3




# Creating a dataset object with transforms
dataset = ImageDataset(data_path, transform=transform)
  
# Getting the split image at the random index chosen earlier
image, target = dataset[random_index]
  
# Plotting the image and target
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title('Image')
plt.subplot(1, 2, 2)
plt.imshow(target)
plt.title('Target')
plt.show()


Output:

[Figure: the split sample, with the 'Image' half on the left and the 'Target' half on the right]

Data augmentation

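We can also define a transform to perform data augmentation. Data augmentation is especially useful when the dataset is small and we want to increase the amount and diversity of the training data. Below is an example of a transform that randomly flips the (image, target) pair vertically and randomly adds pixel noise to it, a simple stand-in for color jittering. Because it operates on the pair produced by CustomTransform, it must come after CustomTransform in the composition.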

Python3




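# Defining a custom augmentation class that operates on the
# (image, target) pair produced by CustomTransform
class CustomAugmentation(object):
    def __init__(self, flip_prob=0.5, jitter_prob=0.5):
        self.flip_prob = flip_prob
        self.jitter_prob = jitter_prob
  
    # Defining the transform method
    def __call__(self, pair):
        image, target = pair

        # Flipping both parts vertically (along the height axis) so they stay aligned;
        # copy() avoids negative-stride views, which torch cannot convert later
        if np.random.random() < self.flip_prob:
            image = np.flip(image, axis=0).copy()
            target = np.flip(target, axis=0).copy()
          
        # Applying a simple "color jitter" by adding random pixel noise
        if np.random.random() < self.jitter_prob:
            image = self._add_noise(image)
            target = self._add_noise(target)
          
        # Returning the augmented pair
        return image, target

    @staticmethod
    def _add_noise(part):
        # Adding noise in a wider integer type, then clipping back to valid uint8 values
        noise = np.random.randint(-50, 50, size=part.shape)
        return np.clip(part.astype(np.int32) + noise, 0, 255).astype(np.uint8)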


Now we will define a transform based on the custom augmentation we defined earlier and display different variations of the target image.

Python3




# Defining augmented and non-augmented transforms
aug_transform = torchvision.transforms.Compose([
    # Replace with the transform object
    CustomTransform(),
    CustomAugmentation(),
])
nonaug_transform = torchvision.transforms.Compose([
    # Replace with the transform object
    CustomTransform(),
])
  
# Creating a dataset object with augmented and non-augmented transforms
aug_dataset = ImageDataset(data_path, transform=aug_transform)
nonaug_dataset = ImageDataset(data_path, transform=nonaug_transform)
  
# Displaying a non-augmented image from the dataset and its augmented versions
image, target = nonaug_dataset[random_index]
  
# Creating a plot
plt.figure(figsize=(10, 10))
# Adding the non-augmented target image
plt.subplot(2, 2, 1)
plt.imshow(target)
plt.title('Non-augmented image')
  
# Adding three augmented versions (each call draws new random augmentations)
for i in range(2, 5):
    image, target1 = aug_dataset[random_index]
    plt.subplot(2, 2, i)
    plt.imshow(target1)
    plt.title('Augmented image')
  
# Displaying the plot
plt.show()


Output:

[Figure: the non-augmented image (top left) and three augmented versions of it]

Custom Dataloaders

A custom dataloader can be defined by wrapping the dataset in the torch.utils.data.DataLoader class. It lets us control various aspects of data loading, such as the batch size, the number of worker processes, and whether to shuffle the data. We can define a custom data loader in PyTorch as follows:

Python3




# Defining a custom data loader
dataloader = torch.utils.data.DataLoader(
    # Replace with the dataset object
    dataset=dataset,
      
    # Defining the batch size
    batch_size=4,
      
    # If true, shuffles the dataset at every epoch
    shuffle=True,
  
    # Number of parallel processes for loading the data
    num_workers=2
)
  
# Get the length of the dataloader 
# (Number of batches)
print('Number of batches:',len(dataloader))


Output:

Number of batches: 274
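
To see what a DataLoader actually yields for a dataset that returns (image, target) pairs, the sketch below uses a small synthetic dataset. The SyntheticPairDataset class is hypothetical and only mimics the output of ImageDataset combined with CustomTransform. The default collate function stacks each position of the tuple separately, so every batch unpacks into an image tensor and a target tensor, each shaped (batch, H, W, C). The sketch uses num_workers=0 to stay portable; with num_workers > 0, the `if __name__ == "__main__":` guard is needed on platforms that start workers with the "spawn" method, which is the default on Windows and macOS.

```python
import numpy as np
import torch

# Hypothetical synthetic dataset returning (image, target) pairs of uint8
# (H, W, C) arrays, mimicking ImageDataset combined with CustomTransform
class SyntheticPairDataset(torch.utils.data.Dataset):
    def __init__(self, length=10, size=8):
        self.length = length
        self.size = size

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        image = np.full((self.size, self.size, 3), index, dtype=np.uint8)
        target = np.full((self.size, self.size, 3), 255 - index, dtype=np.uint8)
        return image, target


def main():
    loader = torch.utils.data.DataLoader(
        SyntheticPairDataset(length=10), batch_size=4, shuffle=True, num_workers=0
    )
    print(len(loader))  # 3 (batches of 4, 4 and 2 samples)
    for images, targets in loader:
        # The default collate function stacks each tuple position separately
        print(images.shape, targets.shape)
    return loader


# Guard required when num_workers > 0 on platforms that use "spawn"
if __name__ == "__main__":
    main()
```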

Training and testing dataset

Now we will combine all of this to define the train and test datasets. Both datasets are preprocessed (split into image and target), but only the training dataset is augmented. The PyTorch implementation is as follows:

Python3




# File path to the dataset, replace this with your path
train_path = './maps/train'
test_path = './maps/val'
  
  
# Defining the train and test transforms
train_transform = torchvision.transforms.Compose([
    CustomTransform(),
    CustomAugmentation(),
])
test_transform = torchvision.transforms.Compose([
    CustomTransform(),
])
  
  
# Creating the train and test datasets
train_dataset = ImageDataset(train_path, transform=train_transform)
test_dataset = ImageDataset(test_path, transform=test_transform)
  
  
# Creating the train and test dataloaders
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2
)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2
)
  
# Printing the length of the train and test dataloaders
print('Number of training batches:',len(train_dataloader))
print('Number of testing batches:',len(test_dataloader))


Output:

Number of training batches: 274
Number of testing batches: 1098
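
The sample dataset used here already has separate train and val folders. If a dataset comes as a single folder instead, one option, sketched below under that assumption, is to create two dataset objects with different transforms and take disjoint subsets of a shared, seeded permutation of indices with torch.utils.data.Subset. Calling torch.utils.data.random_split on a single dataset object would give both splits the same underlying transform, so augmentation would leak into the test split. The TaggedDataset class and its string-tagging transforms are hypothetical stand-ins for ImageDataset with train_transform and test_transform.

```python
import torch

# Hypothetical dataset whose transform simply tags each sample index
class TaggedDataset(torch.utils.data.Dataset):
    def __init__(self, length, transform=None):
        self.length = length
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        sample = index
        if self.transform:
            sample = self.transform(sample)
        return sample


# Two views of the same data with different transforms (stand-ins for
# ImageDataset with train_transform and test_transform)
train_view = TaggedDataset(100, transform=lambda x: ("train", x))
test_view = TaggedDataset(100, transform=lambda x: ("test", x))

# A seeded permutation gives a reproducible, non-overlapping 80/20 split
generator = torch.Generator().manual_seed(42)
indices = torch.randperm(100, generator=generator).tolist()
train_subset = torch.utils.data.Subset(train_view, indices[:80])
test_subset = torch.utils.data.Subset(test_view, indices[80:])

print(len(train_subset), len(test_subset))    # 80 20
print(train_subset[0][0], test_subset[0][0])  # train test
```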

