Standard Deviation Across the Image Channels in PyTorch

In Python, image processing and computer vision tasks often require the calculation of statistical metrics across the color channels of an image. The standard deviation, which measures how far apart values in a dataset are from the mean, is one such metric. In this article, we’ll look at how to use PyTorch to find the standard deviation across all image channels.

First, a quick word on PyTorch itself. PyTorch is a popular open-source machine learning library widely used for deep learning. It provides an efficient way to build and train neural networks, along with several helpful tools for working with image data.

torch.std() method

PyTorch's torch.std() function can be used to calculate the standard deviation across the image channels. It computes the standard deviation of a tensor, either over the whole tensor or along a given dimension.

Syntax of torch.std():

torch.std(input, dim, unbiased, keepdim=False, *, out=None)

Parameters:

  • input (Tensor) – the input tensor.
  • dim (int or tuple of ints) – the dimension or dimensions to reduce.
  • unbiased (bool) – whether to use Bessel’s correction (divide by N − 1 instead of N).
  • keepdim (bool) – whether the output tensor has dim retained or not.
  • out (Tensor, optional) – the output tensor.
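
To make the parameters concrete, here is a small, self-contained sketch using a random tensor (the values are purely illustrative; passing dim as a tuple of dimensions requires a reasonably recent PyTorch version):

Python3

import torch

# A random tensor shaped like a 3-channel 4x4 image: (C, H, W)
x = torch.rand(3, 4, 4)

# Standard deviation of the whole tensor (no dim given)
print(torch.std(x))

# Standard deviation over the spatial dimensions, one value per channel
print(torch.std(x, dim=(1, 2)))

# Population standard deviation (Bessel's correction disabled)
print(torch.std(x, dim=(1, 2), unbiased=False))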

Standard Deviation across the Image Channels in PyTorch

In the case of a color image with RGB channels, we want the standard deviation of each channel, computed over its spatial dimensions. This requires the channel axis to be the first dimension of the tensor. transforms.ToTensor() already returns the image in (C, H, W) order; if the tensor instead arrives channels-last, for example as (H, W, C) from a NumPy array, the dimensions can be reordered with the permute() function in PyTorch.
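
A minimal sketch of that reordering, assuming a hypothetical channels-last tensor rather than the output of ToTensor():

Python3

import torch

# Suppose an image arrives channels-last, e.g. (H, W, C) from a NumPy array
hwc_image = torch.rand(32, 32, 3)

# Move the channel axis to the front: (H, W, C) -> (C, H, W)
chw_image = hwc_image.permute(2, 0, 1)

# Standard deviation of each channel over its spatial dimensions
print(torch.std(chw_image, dim=(1, 2)))   # one value per channel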

Examples of Standard Deviation across the Image Channels

Let’s see a few examples of how we can find the standard deviation across the image channels using PyTorch in Python.

Example 1: RGB image

The example starts by importing the necessary packages: PyTorch, PIL, and transforms from torchvision. It then defines the desired output size of the image. Next, an image is loaded using the PIL library and resized with transforms.Resize from the torchvision.transforms module. The resized image is converted into a PyTorch tensor using transforms.ToTensor() from the same module. After that, the code calculates the standard deviation of each color channel separately using torch.std(). The standard deviation is calculated for each channel by indexing the tensor along its first dimension: 0 for red, 1 for green, and 2 for blue. The standard deviation values for each channel are then printed with the print function, using the item() method to get the plain Python values instead of the whole tensor.

Input image: RGB gfg image (gfg1.jpeg)

Python3




# import packages
import torch
from PIL import Image
import torchvision.transforms as transforms
 
# Define the desired output size
output_size = (32,32)
 
# Load an image using PIL
image = Image.open("gfg1.jpeg")
 
# Resize the image using transforms.Resize
image = transforms.Resize(output_size)(image)
 
# Convert the image to a PyTorch tensor
image_tensor = transforms.ToTensor()(image)
print(image_tensor.shape)
# Get the standard deviation for each channel
red_std = torch.std(image_tensor[0])
green_std = torch.std(image_tensor[1])
blue_std = torch.std(image_tensor[2])
 
# Print the standard deviation of each channel
print('Red channel std dev:', red_std.item())
print('Green channel std dev:', green_std.item())
print('Blue channel std dev:', blue_std.item())


Output:

torch.Size([3, 32, 32])
Red channel std dev: 0.22655275464057922
Green channel std dev: 0.13615494966506958
Blue channel std dev: 0.20094075798988342
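
Instead of indexing each channel separately, the three values can also be obtained in a single call by reducing over the spatial dimensions. A short sketch that assumes the same gfg1.jpeg file as above:

Python3

import torch
from PIL import Image
import torchvision.transforms as transforms

# Load, resize and convert the same image as in the example above
image = transforms.Resize((32, 32))(Image.open("gfg1.jpeg"))
image_tensor = transforms.ToTensor()(image)

# Per-channel standard deviation in one call: reduce over H and W
channel_stds = torch.std(image_tensor, dim=(1, 2))
print('Per-channel std dev:', channel_stds)
# Expected to roughly match the three values printed above, e.g.
# tensor([0.2266, 0.1362, 0.2009])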

Example 2: Grayscale image

When working with grayscale images, the code changes in a few ways. A grayscale image contains only one channel, so there is no need to index the tensor channel by channel. Also, rather than one value per channel, the standard deviation is a single value for the whole image.

To convert the image to grayscale, we call the convert(‘L’) method on the image returned by Image.open(). Because a grayscale image has only one channel, no per-channel indexing is needed. Finally, we compute the standard deviation of the entire tensor, which corresponds to the single channel of the grayscale image.

Python3




import torch
from PIL import Image
import torchvision.transforms as transforms
 
# Define the desired output size
output_size = (32,32)
 
# Load an image using PIL
image = Image.open("gfg1.jpeg").convert('L') # convert to grayscale
 
# Resize the image using transforms.Resize
image = transforms.Resize(output_size)(image)
 
# Convert the image to a PyTorch tensor
image_tensor = transforms.ToTensor()(image)
print(image_tensor.shape)
 
# Find the standard deviation
std_dev = torch.std(image_tensor)
print('standard deviation : ')
print(std_dev)


Output:

torch.Size([1, 32, 32])
standard deviation : 
tensor(0.1696)
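
As a variation, the grayscale conversion can also be done inside the torchvision pipeline with transforms.Grayscale instead of PIL’s convert(‘L’). A minimal sketch, again assuming the same gfg1.jpeg file:

Python3

import torch
from PIL import Image
import torchvision.transforms as transforms

# Resize, convert to a single-channel grayscale image, then to a tensor
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

gray_tensor = preprocess(Image.open("gfg1.jpeg"))
print(gray_tensor.shape)        # torch.Size([1, 32, 32])
print(torch.std(gray_tensor))   # a single value for the one channel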


Last Updated : 08 Jun, 2023