Image Datasets, Dataloaders, and Transforms in PyTorch
Last Updated: 08 Jun, 2023
Deep learning in PyTorch is becoming increasingly popular due to its ease of use, support for multiple hardware platforms, and efficient processing. Image datasets, dataloaders, and transforms are the core building blocks of the data pipeline for a PyTorch vision model.
In this article, we will discuss image datasets, dataloaders, and transforms in Python using the PyTorch library. An image dataset stores a collection of images that can be used to train, validate, or test a deep learning model. The images may come from a variety of sources, such as websites, physical devices, and user-generated content. A dataloader reads samples from a dataset and supplies them to the model in batches. Transforms are operations that alter aspects of an image such as its color, size, shape, or brightness. In PyTorch, these components are used together to build models for tasks such as object recognition, image classification, and image segmentation.
Popular datasets such as ImageNet, CIFAR-10, and MNIST can serve as the basis for image datasets and dataloaders. Common image transforms such as random rotation, random cropping, random horizontal or vertical flipping, normalization, and color augmentation turn raw images into model-ready data. Dataloaders then load batches from the dataset efficiently for model training.
Image Datasets, Dataloaders, and Transforms
We will be implementing them on a sample dataset which can be downloaded from this link. You can download this dataset and follow along with this article to understand the concept better.
Import the necessary libraries
We will first import the libraries we will be using in this article.
Python3
import os
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision
Image Dataset
An image dataset can be created by defining a class that inherits from torch.utils.data.Dataset. A map-style dataset such as this one should implement two methods:
- __len__(): returns the number of samples present in the dataset.
- __getitem__(): returns the sample at the ith index from the dataset.
We can load the image dataset in PyTorch as follows:
Python3
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        # Sort the listing so that a given index always maps to the same file.
        self.images = sorted(os.listdir(data_dir))
        self.transform = transform

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.images)

    def __getitem__(self, index):
        image_path = os.path.join(self.data_dir, self.images[index])
        # Load the image as an (H, W, C) NumPy array.
        image = np.array(Image.open(image_path))
        if self.transform:
            image = self.transform(image)
        return image
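To see the two methods in action without any files on disk, here is a minimal sketch of a dataset backed by an in-memory list. The class name InMemoryDataset and the random arrays are illustrative assumptions, not part of the article's sample dataset.

```python
import numpy as np
import torch


class InMemoryDataset(torch.utils.data.Dataset):
    """Hypothetical dataset that serves samples from a Python list."""

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        if self.transform:
            sample = self.transform(sample)
        return sample


# Three fake 8x8 RGB "images" with random pixel values.
rng = np.random.default_rng(0)
fake_images = [rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8) for _ in range(3)]

demo_dataset = InMemoryDataset(fake_images)
print(len(demo_dataset))      # len() calls __len__
print(demo_dataset[0].shape)  # indexing calls __getitem__
```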
Now let us use this class on our sample dataset.
Python3
data_path = './maps/train'
dataset = ImageDataset(data_path)
dataset_length = len(dataset)
print('Number of training examples:', dataset_length)
random_index = random.randint(0, dataset_length - 1)
plt.imshow(dataset[random_index])
plt.show()
Output:
Number of training examples: 1096
Custom Transforms
A custom transform can be created by defining a class with a __call__() method. Such transforms can be used for both preprocessing and data augmentation. As an example, we can define a custom transform that preprocesses the input image by splitting it along its width into two parts, which are equal halves with the default split_percent of 0.5:
Python3
class CustomTransform(object):
    def __init__(self, split_percent=0.5):
        self.split_percent = split_percent

    def __call__(self, image):
        # image is an (H, W, C) array; split it along the width axis.
        split = int(image.shape[1] * self.split_percent)
        image1 = image[:, :split, :]
        image2 = image[:, split:, :]
        return image1, image2
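The effect of the split on array shapes can be checked on a small synthetic array. In the sketch below, split_image is a hypothetical helper that mirrors the slicing in CustomTransform.__call__, and the 4x6 image is made up for the demonstration.

```python
import numpy as np


def split_image(image, split_percent=0.5):
    """Hypothetical helper mirroring the slicing in CustomTransform.__call__."""
    split = int(image.shape[1] * split_percent)
    return image[:, :split, :], image[:, split:, :]


# A fake image: height 4, width 6, 3 channels.
image = np.arange(4 * 6 * 3, dtype=np.uint8).reshape(4, 6, 3)

left, right = split_image(image)
print(left.shape, right.shape)

# An uneven split keeps the full height and divides only the width.
narrow, wide = split_image(image, split_percent=0.25)
print(narrow.shape, wide.shape)
```

Joining the two parts again along axis 1 reproduces the original array, which confirms that the split only partitions the columns.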
To use multiple transform objects in PyTorch, you can make use of the torchvision.transforms.Compose class. This class allows you to create an object that represents a composition of different transform objects while maintaining the order in which you want them to be applied.
Python3
transform = torchvision.transforms.Compose([
    CustomTransform(),
])
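Conceptually, a composition feeds the output of each transform into the next one. The sketch below uses a hypothetical SimpleCompose class and two toy functions to illustrate why the order matters; it is a simplified stand-in, not torchvision's implementation.

```python
class SimpleCompose:
    """Hypothetical stand-in that applies callables in list order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for transform in self.transforms:
            x = transform(x)
        return x


def add_one(x):
    return x + 1


def double(x):
    return x * 2


add_then_double = SimpleCompose([add_one, double])
double_then_add = SimpleCompose([double, add_one])

print(add_then_double(3))  # (3 + 1) * 2
print(double_then_add(3))  # (3 * 2) + 1
```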
We can now use this transform with the custom dataset class:
Python3
dataset = ImageDataset(data_path, transform=transform)
image, target = dataset[random_index]
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title('Image')
plt.subplot(1, 2, 2)
plt.imshow(target)
plt.title('Target')
plt.show()
Output: a figure with two panels titled 'Image' and 'Target', one for each half of the selected sample.
Data augmentation
We can also define a transform to perform data augmentation. Data augmentation is useful when the dataset is small and we want to increase the amount and diversity of the training data. Below is an example of a transform that randomly flips the input vertically and randomly jitters its colors by adding noise to the pixel values.
Python3
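class CustomAugmentation(object):
    def __init__(self, flip_prob=0.5, jitter_prob=0.5):
        self.flip_prob = flip_prob
        self.jitter_prob = jitter_prob

    def __call__(self, pair):
        # CustomTransform returns two equally sized halves; stack them as
        # (2, H, W, C) so that both halves receive the same augmentation.
        pair = np.stack(pair)
        if np.random.random() < self.flip_prob:
            # Axis 1 of the stacked array is the height: a vertical flip.
            pair = np.flip(pair, axis=1)
        if np.random.random() < self.jitter_prob:
            # Add per-pixel noise in [-50, 50), then clip back to valid 8-bit values.
            noise = np.random.randint(-50, 50, size=pair.shape, dtype=np.int32)
            pair = np.clip(pair.astype(np.int32) + noise, 0, 255).astype(np.uint8)
        # np.flip returns a view with negative strides; make the result contiguous.
        image, target = np.ascontiguousarray(pair)
        return image, target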
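The cast to a wider integer type before adding noise matters because 8-bit pixel values wrap around on overflow. The following sketch, using made-up pixel values, shows the wraparound and one possible way to keep the result in the displayable 0 to 255 range. Clipping is an assumption about the desired behavior here, not the only reasonable choice.

```python
import numpy as np

pixels = np.array([250, 5, 128], dtype=np.uint8)
noise = np.array([10, -10, 0], dtype=np.int32)

# Adding in uint8 wraps around modulo 256.
wrapped = pixels + noise.astype(np.uint8)
print(wrapped)

# Adding in int32 keeps the true sums, which may leave the 0-255 range.
widened = pixels.astype(np.int32) + noise
print(widened)

# Clipping and casting back gives a valid 8-bit image again.
clipped = np.clip(widened, 0, 255).astype(np.uint8)
print(clipped)
```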
Now we will define a transform based on the custom augmentation we defined earlier and display different variations of the target image.
Python3
aug_transform = torchvision.transforms.Compose([
    CustomTransform(),
    CustomAugmentation(),
])
nonaug_transform = torchvision.transforms.Compose([
    CustomTransform(),
])
aug_dataset = ImageDataset(data_path, transform=aug_transform)
nonaug_dataset = ImageDataset(data_path, transform=nonaug_transform)
image, target = nonaug_dataset[random_index]
plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.imshow(target)
plt.title('Non augmented image')
for i in range(2, 5):
    image, target1 = aug_dataset[random_index]
    plt.subplot(2, 2, i)
    plt.imshow(target1)
    plt.title('Augmented image')
plt.show()
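Output: a 2x2 grid of panels, with the non-augmented target image followed by three randomly augmented versions of it.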
Custom Dataloaders
A dataloader is created by wrapping the dataset in the torch.utils.data.DataLoader class. It lets us control aspects of loading such as the batch size, the number of worker processes, and whether to shuffle the data. We can define a data loader in PyTorch as follows:
Python3
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2
)
print('Number of batches:', len(dataloader))
Output:
Number of batches: 274
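The batch count follows from the dataset size and the batch size: with the default drop_last=False, the length of a DataLoader is the number of samples divided by the batch size, rounded up. Here is a small sketch of that arithmetic, where num_batches is a hypothetical helper and the 10-sample case is made up to show a partial final batch.

```python
import math


def num_batches(num_samples, batch_size, drop_last=False):
    """Batches per epoch for a dataset of known length."""
    if drop_last:
        # The incomplete final batch is discarded.
        return num_samples // batch_size
    # The incomplete final batch is kept.
    return math.ceil(num_samples / batch_size)


print(num_batches(1096, 4))                # matches the output above
print(num_batches(10, 4))                  # last batch holds only 2 samples
print(num_batches(10, 4, drop_last=True))  # partial batch dropped
```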
Training and testing dataset
Now we will combine everything covered so far to define the training and test datasets. Preprocessing is applied to both datasets, while augmentation is applied only to the training dataset. The PyTorch implementation is as follows:
Python3
train_path = './maps/train'
test_path = './maps/val'
train_transform = torchvision.transforms.Compose([
    CustomTransform(),
    CustomAugmentation(),
])
test_transform = torchvision.transforms.Compose([
    CustomTransform(),
])
train_dataset = ImageDataset(train_path, transform=train_transform)
test_dataset = ImageDataset(test_path, transform=test_transform)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2
)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2
)
print('Number of training batches:', len(train_dataloader))
print('Number of testing batches:', len(test_dataloader))
Output:
Number of training batches: 274
Number of testing batches: 1098
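Once the loaders exist, they are typically consumed in a loop that yields one batch per iteration. The sketch below shows that pattern with synthetic tensors rather than the ./maps images, so the tensor shapes and the TensorDataset wrapper are assumptions made purely for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten fake (input, target) pairs of 3x8x8 tensors stand in for the image halves.
inputs = torch.zeros(10, 3, 8, 8)
targets = torch.ones(10, 3, 8, 8)
toy_dataset = TensorDataset(inputs, targets)

# num_workers=0 loads data in the main process, which keeps the demo simple.
toy_loader = DataLoader(toy_dataset, batch_size=4, shuffle=False, num_workers=0)

batch_shapes = []
for batch_inputs, batch_targets in toy_loader:
    # A real training loop would run the model and an optimizer step here.
    batch_shapes.append(tuple(batch_inputs.shape))

print(batch_shapes)
```

The final batch holds only the two leftover samples, which is why code that assumes a fixed batch size often either sets drop_last=True or reads the size from the batch itself.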