Load a Computer Vision Dataset in PyTorch
Computer vision is a subfield of artificial intelligence that enables computers to understand images. In deep learning, convolutional neural networks (CNNs) are commonly used to process images. Building a good model requires a large number of images.
There are several ways to load a computer vision dataset in PyTorch, depending on the format of the dataset and the specific requirements of your project.
One popular method is to use the built-in dataset classes in torchvision.datasets, part of the torchvision package that accompanies PyTorch. They provide a convenient way to load and preprocess common computer vision datasets, such as CIFAR-10 and ImageNet. For example, to load the CIFAR-10 dataset, you can use the following code:
Python3
import torchvision.datasets as datasets

# Download the training and test splits of CIFAR-10 into ./data
cifar10_train = datasets.CIFAR10(root="./data", train=True, download=True)
cifar10_test = datasets.CIFAR10(root="./data", train=False, download=True)
Output:
CIFAR-10
The code above downloads the CIFAR-10 dataset and saves it in the './data' directory.
Another method is to use the torch.utils.data.DataLoader class to load the data. A DataLoader wraps a dataset and is especially useful when the data is on your local machine and you want to shuffle it and specify the batch size; data augmentation is applied through the dataset's transforms. It lets you customize the data loading order, batching, and single- or multi-process data loading.
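To make the batching and shuffling behaviour concrete, here is a small sketch that wraps a toy dataset in a DataLoader. The random tensors and the names toy_dataset and toy_loader are hypothetical and only stand in for real image data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 RGB "images" of 32x32 pixels with integer labels.
images = torch.rand(10, 3, 32, 32)
labels = torch.arange(10)
toy_dataset = TensorDataset(images, labels)

# batch_size groups samples into batches; shuffle=True reorders them every epoch.
toy_loader = DataLoader(toy_dataset, batch_size=4, shuffle=True)

batch_shapes = []
for batch_images, batch_labels in toy_loader:
    batch_shapes.append(tuple(batch_images.shape))
    print(batch_images.shape, batch_labels.tolist())
```

With 10 samples and a batch size of 4, the last batch holds only 2 samples; passing drop_last=True to DataLoader would discard it instead.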
Here we use transforms.Compose from torchvision to chain transforms that randomly flip and rotate each image, convert it to a tensor, and normalize it.
Python3
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Augmentation and preprocessing applied to every image when it is loaded
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.35, 0.35, 0.406], [0.30, 0.34, 0.35]),
])

# download=False assumes the dataset is already present in ./data
cifar10_train = datasets.CIFAR10(root="./data", train=True, download=False, transform=transform)
cifar10_test = datasets.CIFAR10(root="./data", train=False, download=False, transform=transform)

train_loader = DataLoader(cifar10_train, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(cifar10_test, batch_size=32, shuffle=False, num_workers=2)
View the train and test data
Python3
print(train_loader.dataset)
print(test_loader.dataset)
Output:
Dataset CIFAR10
    Number of datapoints: 50000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.35, 0.35, 0.406], std=[0.3, 0.34, 0.35])
           )
Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.35, 0.35, 0.406], std=[0.3, 0.34, 0.35])
           )
Plot one batch of images:
Python3
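import matplotlib.pyplot as plt

# Fetch one batch of 32 images and their labels
inputs, labels = next(iter(train_loader))

# Map CIFAR-10 label indices to class names
class_name = {0: 'airplane',
              1: 'automobile',
              2: 'bird',
              3: 'cat',
              4: 'deer',
              5: 'dog',
              6: 'frog',
              7: 'horse',
              8: 'ship',
              9: 'truck'}

# dpi=1000 is very high; lower it (e.g. 100) for faster rendering
plt.figure(figsize=(30, 16), dpi=1000)
for i in range(32):
    plt.subplot(4, 8, i + 1)
    # Convert from (C, H, W) to (H, W, C), the layout imshow expects
    plt.imshow(inputs[i].numpy().transpose((1, 2, 0)))
    plt.axis('off')
    plt.title(class_name[int(labels[i])])
plt.show()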
Output:
CIFAR-10
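Because the images were normalized, their pixel values no longer lie in [0, 1], and imshow clips anything outside that range. One way to get more natural-looking images is to undo the normalization before plotting. The helper denormalize below is hypothetical and not part of torchvision; it is a sketch that assumes the mean and std values passed to Normalize earlier in this article.

```python
import torch

# Mean and std copied from the Normalize transform used in this article.
MEAN = torch.tensor([0.35, 0.35, 0.406]).view(3, 1, 1)
STD = torch.tensor([0.30, 0.34, 0.35]).view(3, 1, 1)

def denormalize(img: torch.Tensor) -> torch.Tensor:
    """Invert (x - mean) / std for one (C, H, W) image and clamp to [0, 1]."""
    return (img * STD + MEAN).clamp(0.0, 1.0)

# Round-trip check on a random image standing in for a CIFAR-10 sample.
original = torch.rand(3, 32, 32)
normalized = (original - MEAN) / STD
restored = denormalize(normalized)
print(torch.allclose(restored, original, atol=1e-5))
```

In the plotting loop, this could be used as plt.imshow(denormalize(inputs[i]).permute(1, 2, 0).numpy()).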
Other libraries, such as albumentations, can also be used to augment and preprocess the data. Which one you choose depends on the format of your data and what you are trying to achieve.
You might also want to check the version of PyTorch you’re using, as well as the format of the dataset you’re trying to load. Some datasets come in a custom format, in which case you need to write your own code to load them correctly.
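For a custom format, the usual approach is to subclass torch.utils.data.Dataset and implement __len__ and __getitem__. The following is a minimal sketch; the class name InMemoryImageDataset is hypothetical, and random tensors stand in for whatever your own parsing code would produce.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class InMemoryImageDataset(Dataset):
    """Hypothetical dataset wrapping images and labels already held in memory."""

    def __init__(self, images, labels, transform=None):
        assert len(images) == len(labels)
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Number of samples; DataLoader uses this to build its index list.
        return len(self.images)

    def __getitem__(self, index):
        # Return one (image, label) pair, applying the optional transform.
        image = self.images[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[index]

# Random tensors stand in for data parsed from a custom file format.
custom_dataset = InMemoryImageDataset(torch.rand(6, 3, 32, 32),
                                      torch.tensor([0, 1, 2, 0, 1, 2]))
custom_loader = DataLoader(custom_dataset, batch_size=3, shuffle=False)

first_images, first_labels = next(iter(custom_loader))
print(first_images.shape, first_labels)
```

Once the class exists, it can be passed to DataLoader exactly like the built-in CIFAR10 dataset.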
Last Updated: 19 Apr, 2023