Load a Computer Vision Dataset in PyTorch

Computer vision is a subfield of artificial intelligence that enables computers to understand images. In deep learning, convolutional neural networks (CNNs) are commonly used to process images, and building a good model requires a large number of images.

There are several ways to load a computer vision dataset in PyTorch, depending on the format of the dataset and the specific requirements of your project.

One popular method is to use the built-in dataset classes in torchvision.datasets. They provide a convenient way to load and preprocess common computer vision datasets, such as CIFAR-10 and ImageNet. For example, to load the CIFAR-10 dataset, you can use the following code:

Python3




# Import the necessary library
import torchvision.datasets as datasets
 
 
# Download the cifar Dataset
cifar10_train = datasets.CIFAR10(root="./data", train=True, download=True)
cifar10_test = datasets.CIFAR10(root="./data", train=False, download=True)


Output:

[Figure: CIFAR-10]

The code above downloads the CIFAR-10 dataset and saves it in the ./data directory.

Another method is to wrap a dataset in the torch.utils.data.DataLoader class. This is especially useful when the data is on your local machine and you want to shuffle it, specify the batch size, and apply data augmentation through the dataset's transforms. DataLoader also lets you customize the loading order, batch samples automatically, and load data in a single process or in multiple worker processes.

Here we use the transforms.Compose class from torchvision to chain transformations that randomly flip and rotate each image, convert it to a tensor, and normalize it.

Python3




# Import the necessary libraries
import torchvision.datasets as datasets
from torchvision import transforms
from torch.utils.data import DataLoader
 
# Image Transformation
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.35, 0.35, 0.406], [0.30, 0.34, 0.35])
])
 
# Load the dataset with transformation
cifar10_train = datasets.CIFAR10(root="./data", train=True, download=False, transform=transform)
cifar10_test = datasets.CIFAR10(root="./data", train=False, download=False, transform=transform)
 
# Create data loaders with a batch size of 32
train_loader = DataLoader(cifar10_train, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(cifar10_test, batch_size=32, shuffle=False, num_workers=2)


View the train and test data:

Python3




#Train Dataset
print(train_loader.dataset)
#Test Dataset
print(test_loader.dataset)


Output:

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.35, 0.35, 0.406], std=[0.3, 0.34, 0.35])
           )
Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.35, 0.35, 0.406], std=[0.3, 0.34, 0.35])
           )

Plot the image:

Python3




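# Import the plotting library
import matplotlib.pyplot as plt
import torch

# Fetch one batch of images and labels
inputs, labels = next(iter(train_loader))

# Define the class names
class_name = {0: 'airplane',
              1: 'automobile',
              2: 'bird',
              3: 'cat',
              4: 'deer',
              5: 'dog',
              6: 'frog',
              7: 'horse',
              8: 'ship',
              9: 'truck'
              }

# Undo the normalization so pixel values fall back into [0, 1] for display
mean = torch.tensor([0.35, 0.35, 0.406]).view(3, 1, 1)
std = torch.tensor([0.30, 0.34, 0.35]).view(3, 1, 1)
images = (inputs * std + mean).clamp(0, 1)

# Plot the batch in a 4 x 8 grid
plt.figure(figsize=(16, 8))
for i in range(len(images)):
    plt.subplot(4, 8, i + 1)
    # imshow expects (height, width, channels), so move channels last
    plt.imshow(images[i].permute(1, 2, 0).numpy())
    plt.axis('off')
    plt.title(class_name[int(labels[i])])

plt.show()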


Output:

[Figure: a batch of 32 CIFAR-10 training images, each titled with its class name]

Other libraries, such as albumentations, can also be used to preprocess and augment the data. The right choice depends on the format of your data and what you are trying to achieve.

You might also want to check the version of PyTorch you're using, as well as the format of the dataset you're trying to load. Some datasets are stored in a custom format, and you might need to write your own code to load them correctly.



Last Updated : 19 Apr, 2023