
Converting an image to a Torch Tensor in Python

In this article, we will see how to convert an image to a PyTorch tensor. A tensor in PyTorch is similar to a NumPy array: a container whose elements all share the same dtype.

A tensor may be a scalar, one-dimensional, or multi-dimensional. To convert an image to a tensor in PyTorch we use the PILToTensor() and ToTensor() transforms, both provided in the torchvision.transforms package. PILToTensor() accepts a PIL image; ToTensor() accepts either a PIL image or a numpy.ndarray. The numpy.ndarray must be in [H, W, C] format, where H, W, and C are the height, width, and number of channels of the image. Both transforms return a tensor in [C, H, W] format.



transform = transforms.Compose([transforms.PILToTensor()])

tensor = transform(img)



This transform converts a PIL image to a tensor of data type torch.uint8 with values in the range 0 to 255. Here img is a PIL image.

transform = transforms.Compose([transforms.ToTensor()])

tensor = transform(img)

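This transform converts a PIL image or a numpy.ndarray to a torch tensor. For a numpy.ndarray of dtype uint8 (or a PIL image in a common mode such as RGB), the result has data type torch.float32 and the values are scaled from [0, 255] to the range [0, 1]. Other inputs are converted without scaling. Here img is a numpy.ndarray.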

Approach:

Both examples use the same input image, iceland.jpg.

Example 1:

In the example below, we convert a PIL image to a Torch tensor.




# Import necessary libraries
import torch
from PIL import Image
import torchvision.transforms as transforms
  
# Read a PIL image
image = Image.open('iceland.jpg')
  
# Define a transform to convert PIL 
# image to a Torch tensor
transform = transforms.Compose([
    transforms.PILToTensor()
])
  
# Compose is optional for a single transform; this is equivalent:
# transform = transforms.PILToTensor()

# Convert the PIL image to Torch tensor
img_tensor = transform(image)
  
# print the converted Torch tensor
print(img_tensor)

Output:

Notice that the data type of the output tensor is torch.uint8 and the values are in the range [0, 255].

Example 2:

In this example, we read an image using OpenCV. OpenCV returns the image as a numpy.ndarray in [H, W, C] format with the channels in BGR order, so we first convert it to RGB. We then convert it to a torch tensor using the transform ToTensor().




# Import required libraries
import torch
import cv2
import torchvision.transforms as transforms
  
# Read the image (cv2.imread returns a numpy.ndarray in BGR order,
# or None if the file cannot be read)
image = cv2.imread('iceland.jpg')
  
# Convert BGR image to RGB image
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  
# Define a transform to convert
# the image to torch tensor
transform = transforms.Compose([
    transforms.ToTensor()
])
  
# Convert the image to Torch tensor
tensor = transform(image)
  
# print the converted image tensor
print(tensor)

Output:

Notice that the data type of the output tensor is torch.float32 and the values are in the range [0, 1].

