Datasets and DataLoaders in PyTorch
PyTorch is a Python library developed by Facebook for building and training machine learning and deep learning models. Before a model can be trained, the data has to be converted into a format the model can process. PyTorch provides the torch.utils.data module, which makes data loading easier through its Dataset and DataLoader classes.
A Dataset object is the first argument of the DataLoader constructor and specifies the data to load from. There are two types of datasets:
- map-style datasets: These datasets implement two methods, __getitem__() and __len__(), which return the sample at a given index and the total number of samples, respectively. The example in this article uses this type of dataset.
- iterable-style datasets: These datasets represent the data as an iterable stream of samples and implement the __iter__() method instead. They are useful when random access by index is expensive or not possible.
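The difference between the two dataset styles described above can be sketched as follows. This is a minimal illustration, not part of the original example: the class names SquaresMapDataset and SquaresIterableDataset and the toy data (numbers paired with their squares) are hypothetical.

```python
import torch
from torch.utils.data import Dataset, IterableDataset


class SquaresMapDataset(Dataset):
    """Map-style: supports random access by index and reports its length."""

    def __init__(self, n):
        self.values = torch.arange(n, dtype=torch.float32)

    def __getitem__(self, index):
        # Return the sample (input, target) stored at the given index.
        x = self.values[index]
        return x, x ** 2

    def __len__(self):
        # Return the total number of samples.
        return len(self.values)


class SquaresIterableDataset(IterableDataset):
    """Iterable-style: yields samples one at a time, with no indexing."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        for i in range(self.n):
            x = torch.tensor(float(i))
            yield x, x ** 2


map_ds = SquaresMapDataset(5)
iter_ds = SquaresIterableDataset(5)

third_x, third_y = map_ds[3]  # random access works on the map-style dataset
streamed = [(x.item(), y.item()) for x, y in iter_ds]  # the iterable one is consumed in order
```

The map-style version can be indexed and measured with len(), while the iterable-style version can only be walked through from start to finish.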
DataLoader, on the other hand, lets us iterate through the dataset in batches and also provides built-in support for shuffling, multiprocessing (worker processes load several batches in parallel instead of one batch at a time), and more. Its constructor has the following signature:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False)
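A short sketch of how a few of these arguments change the batches that come out of the loader. The toy tensors are an assumption made for illustration only, and TensorDataset is used simply as a ready-made map-style dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (assumption): 10 samples with 2 features each, plus one label per sample.
features = torch.arange(20, dtype=torch.float32).reshape(10, 2)
labels = torch.arange(10, dtype=torch.float32).unsqueeze(1)
toy_dataset = TensorDataset(features, labels)

# drop_last=False keeps the final, smaller batch; drop_last=True discards it.
loader_keep = DataLoader(toy_dataset, batch_size=4, shuffle=False, drop_last=False)
loader_drop = DataLoader(toy_dataset, batch_size=4, shuffle=False, drop_last=True)

batch_sizes_keep = [xb.shape[0] for xb, yb in loader_keep]
batch_sizes_drop = [xb.shape[0] for xb, yb in loader_drop]

print(batch_sizes_keep, batch_sizes_drop)
```

With 10 samples and batch_size=4, the last batch holds only 2 samples, which is why drop_last matters when a model expects a fixed batch size. num_workers is left at its default of 0 here, so all loading happens in the main process.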
Dataset Used: heart
Let us work through an example to make the concept clearer.
First, import the required libraries and load the dataset. The data is stored as torch tensors, and individual samples are retrieved by index through the __getitem__() protocol. We then unpack one sample and print its features and label.
tensor([ 63.0000, 1.0000, 3.0000, 145.0000, 233.0000, 1.0000, 0.0000,
150.0000, 0.0000, 2.3000, 0.0000, 0.0000, 1.0000]) tensor([1.])
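A minimal sketch of a map-style dataset that could produce output like the above. Several things here are assumptions: the class name HeartDataset, the layout of 13 feature columns followed by one label column, and the file path in the comment. To keep the sketch runnable without the file, it uses the one row printed above plus synthetic filler rows that are not real patient data.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class HeartDataset(Dataset):
    """Map-style dataset over a 2-D table (hypothetical class name)."""

    def __init__(self, table):
        # Assumption: the last column is the label, all others are features.
        table = np.asarray(table, dtype=np.float32)
        self.x = torch.tensor(table[:, :-1])   # shape: (n_samples, n_features)
        self.y = torch.tensor(table[:, [-1]])  # shape: (n_samples, 1)
        self.n_samples = table.shape[0]

    def __getitem__(self, index):
        # Return the (features, label) pair at the given index.
        return self.x[index], self.y[index]

    def __len__(self):
        return self.n_samples


# With the real file, the table might be loaded with something like
# np.loadtxt("heart.csv", delimiter=",", skiprows=1)  (hypothetical path and layout).
first_row = [63, 1, 3, 145, 233, 1, 0, 150, 0, 2.3, 0, 0, 1, 1]  # row printed above, label last
rng = np.random.default_rng(0)
placeholder_rows = rng.random((7, 14))  # synthetic filler rows, not real data
table = np.vstack([first_row, placeholder_rows])

dataset = HeartDataset(table)
features, label = dataset[0]  # calls __getitem__(0) and unpacks the pair
print(features, label)
```

Storing the label as a column of shape (n_samples, 1) is what makes each label print as a one-element tensor such as tensor([1.]).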
The DataLoader takes this dataset as input, along with other arguments such as batch_size and shuffle. We then work out how the samples divide into batches and print the features and labels one batch at a time.
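A sketch of this step under stated assumptions: random tensors with 13 features stand in for the heart data, batch_size=4 is an arbitrary choice, and n_iterations is a hypothetical variable name for the number of batches per epoch.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Stand-in data (assumption): 10 samples, 13 features, one binary label each.
features = torch.rand(10, 13)
labels = torch.randint(0, 2, (10, 1)).float()
dataset = TensorDataset(features, labels)

batch_size = 4
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Number of batches needed to cover every sample once (the last batch may be smaller).
total_samples = len(dataset)
n_iterations = math.ceil(total_samples / batch_size)

# Pull a single batch to inspect its shape.
batch_features, batch_labels = next(iter(dataloader))
print(total_samples, n_iterations)
print(batch_features.shape, batch_labels.shape)
```

Because drop_last is left at False, len(dataloader) matches the rounded-up count computed by hand.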
We now iterate over the data the way a training loop would: the outer loop runs over the epochs and the inner loop over the batches of samples, printing the epoch number, the input tensor, and the label tensor at each iteration.
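The nested loop described above might look like the sketch below. It only iterates and prints shapes; no model, loss, or optimizer is defined, and the stand-in data, num_epochs=2, and batch_size=4 are assumptions made for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Stand-in data (assumption): 10 samples, 13 features, one binary label each.
dataset = TensorDataset(torch.rand(10, 13), torch.randint(0, 2, (10, 1)).float())
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

num_epochs = 2
samples_seen = 0

for epoch in range(num_epochs):
    # Each pass over the dataloader visits every sample once, in shuffled order.
    for step, (inputs, labels) in enumerate(dataloader):
        # A real training loop would run the forward pass, compute the loss,
        # and update the model parameters at this point.
        samples_seen += inputs.shape[0]
        print(
            f"epoch {epoch + 1}/{num_epochs}, step {step + 1}/{len(dataloader)}, "
            f"inputs {tuple(inputs.shape)}, labels {tuple(labels.shape)}"
        )
```

With shuffle=True the loader draws a fresh ordering at the start of every epoch, so the batches differ between epochs even though each epoch still covers all samples.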