
Apply torch.inverse() Function of PyTorch to Every Sample in the Batch

Last Updated : 21 Mar, 2023

PyTorch is a deep learning framework that provides a variety of functions to perform different operations on tensors. One such function is torch.inverse(), which can be used to compute the inverse of a square matrix.

Sometimes we have a batch of matrices, for example one matrix per sample, and we want to invert each of them. torch.inverse() supports this directly. Its input may have any number of leading batch dimensions, and the function inverts every square matrix along those dimensions in a single call, with no Python loop. The result is a new tensor with the same shape as the input. Each matrix in it (not each individual element) is the inverse of the corresponding matrix in the input tensor.
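The sketch below shows concretely what "every sample in the batch" means. It compares one batched call with an explicit Python loop that inverts each matrix separately and stacks the results. The batch size of 4, the fixed seed and the float64 dtype are arbitrary choices for illustration. The two results should agree up to floating-point tolerance.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable
batch = torch.randn(4, 3, 3, dtype=torch.float64)

# One batched call: inverts each 3x3 matrix along the leading dimension
batched = torch.inverse(batch)

# Explicit per-sample loop for comparison: iterating over a tensor
# yields its slices along dim 0, i.e. one 3x3 matrix at a time
looped = torch.stack([torch.inverse(m) for m in batch])

print(batched.shape)                    # torch.Size([4, 3, 3])
print(torch.allclose(batched, looped))  # True
```

In practice the batched call is preferable. It avoids Python-level iteration and lets PyTorch dispatch the whole batch at once.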

The examples below demonstrate how to apply the torch.inverse() function to every sample in a batch. In each one, we first create a batch of matrices and then call torch.inverse() once to find the inverse of every matrix in the batch.

Syntax of torch.inverse():

torch.inverse() computes the inverse of the square matrix input. input can also be a batch of 2D square tensors, in which case the function returns a tensor composed of the individual inverses. In recent PyTorch versions, torch.inverse() is an alias for torch.linalg.inv().
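Because a batched call inverts every sample, a single singular (non-invertible) matrix in the batch makes torch.inverse() raise an error for the whole call. One way to handle this is torch.linalg.inv_ex(). It returns the inverses together with a per-sample info tensor instead of raising. This is a swapped-in alternative rather than part of the method above, and it assumes PyTorch 1.9 or newer. The batch values are hypothetical and chosen so that the second matrix is singular.

```python
import torch

# Hypothetical batch: the second matrix is singular (its row 2 = 2 * row 1)
batch = torch.tensor([[[2.0, 0.0], [0.0, 4.0]],
                      [[1.0, 2.0], [2.0, 4.0]]])

raised = False
try:
    torch.inverse(batch)
except RuntimeError as err:  # linalg errors subclass RuntimeError
    # One singular sample makes the whole batched call fail
    raised = True
    print("torch.inverse raised:", err)

# torch.linalg.inv_ex reports failures per sample instead of raising;
# info holds 0 for samples that were inverted successfully
inv, info = torch.linalg.inv_ex(batch)
print(info)
print(inv[0])  # inverse of the first (invertible) matrix
```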

Syntax: torch.inverse(input, *, out=None)

Parameters:

  • input (Tensor) – the input tensor of size (∗, n, n), where ∗ is zero or more batch dimensions

Keyword Arguments:

  • out (Tensor, optional) – the output tensor.
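The next sketch illustrates the (∗, n, n) input shape with more than one batch dimension. It also exercises the keyword-only out argument. The shapes and seed are arbitrary illustrative choices. The last line relies on torch.inverse() being an alias for torch.linalg.inv().

```python
import torch

torch.manual_seed(0)

# Two batch dimensions: 4 groups of 5 matrices, each 3x3
x = torch.randn(4, 5, 3, 3, dtype=torch.float64)
y = torch.inverse(x)
print(y.shape)  # torch.Size([4, 5, 3, 3])

# Write the result into a preallocated tensor via the keyword-only `out`
buf = torch.empty_like(x)
torch.inverse(x, out=buf)
print(torch.allclose(buf, y))  # True

# torch.linalg.inv gives the same batched result
print(torch.allclose(torch.linalg.inv(x), y))  # True
```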

Example 1:

Suppose we have a batch of 2 matrices, each of shape (3, 3). We can create this batch with the torch.randn() function, which gives an input tensor of shape (2, 3, 3). Matrices with random normal entries are invertible with probability 1. We then apply the torch.inverse() function to the entire input tensor, which computes the inverse of each 3×3 matrix in the batch. The resulting output tensor also has shape (2, 3, 3), and each 3×3 matrix in it is the inverse of the corresponding matrix in the input tensor. Because the input is random, the values you get will differ from the output shown below.

Python3




import torch
  
# Create a batch of 2 matrices with shape (2, 3, 3)
batch_size = 2
input_tensor = torch.randn(batch_size, 3, 3)
  
# Compute the inverse of each matrix in the batch
output_tensor = torch.inverse(input_tensor)
  
# Print the input and output tensors
print("Input tensor:")
print(input_tensor)
print("Output tensor:")
print(output_tensor)


Output:

Input tensor:
tensor([[[-0.9808, -1.5437,  1.1773],
         [-0.8945, -1.2584,  1.6299],
         [ 0.8855,  0.3088, -1.4001]],

        [[ 0.4860, -0.8735, -1.1052],
         [-0.4737, -2.8350,  0.1861],
         [ 1.7559, -0.4935,  0.7353]]])
Output tensor:
tensor([[[-2.3209,  3.3154,  1.9079],
         [-0.3517, -0.6101, -1.0059],
         [-1.5453,  1.9621,  0.2705]],

        [[ 0.2723, -0.1623,  0.4503],
         [-0.0923, -0.3140, -0.0592],
         [-0.7122,  0.1768,  0.2448]]])
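One way to confirm a result like Example 1's is to multiply each input matrix by its computed inverse and compare the product with the identity matrix. The sketch below does this for the whole batch with one batched matrix product. It differs from Example 1 in two assumed details: it uses a fixed seed so the check is repeatable, and float64 instead of the default float32 so the tolerance can be tight.

```python
import torch

torch.manual_seed(0)
input_tensor = torch.randn(2, 3, 3, dtype=torch.float64)
output_tensor = torch.inverse(input_tensor)

# Batched matrix product: each input matrix times its own inverse
products = input_tensor @ output_tensor  # shape (2, 3, 3)

# Expand the single 3x3 identity to match the batch shape
identity = torch.eye(3, dtype=torch.float64).expand_as(products)
print(torch.allclose(products, identity, atol=1e-8))  # True
```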

Example 2:

In this example, we use the torch.randn() function to generate a batch of three 2×2 matrices, giving a tensor of shape (3, 2, 2). We then call the tensor method input_tensor.inverse(), which is equivalent to torch.inverse(input_tensor) and already inverts every matrix in the batch. The torch.ones() function produces a tensor of ones with shape (batch_size, 1, 1). Multiplying the result by it broadcasts across each 2×2 matrix but leaves every value unchanged, so this step only illustrates broadcasting and is not needed for the inversion. Every 2×2 matrix in the resulting output tensor, which also has shape (3, 2, 2), is the inverse of the corresponding matrix in the input tensor.

Python3




import torch
  
# Create a batch of 3 matrices with shape (3, 2, 2)
batch_size = 3
input_tensor = torch.randn(batch_size, 2, 2)
  
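# Create a tensor of ones with shape (batch_size, 1, 1); multiplying by it
# broadcasts over each 2x2 matrix but leaves every value unchanged
ones = torch.ones(batch_size, 1, 1)

# Compute the inverse of each matrix in the batch
# (input_tensor.inverse() alone already inverts every matrix)
output_tensor = input_tensor.inverse() * ones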
  
# Print the input and output tensors
print("Input tensor:")
print(input_tensor)
print("Output tensor:")
print(output_tensor)


Output:

Input tensor:
tensor([[[-0.1727, -0.5076],
         [ 0.9635,  0.0972]],

        [[ 1.7375,  1.6074],
         [ 0.0697, -0.8704]],

        [[-0.6624,  1.8799],
         [ 1.1704, -0.1165]]])
Output tensor:
tensor([[[ 0.2057,  1.0747],
         [-2.0400, -0.3656]],

        [[ 0.5358,  0.9895],
         [ 0.0429, -1.0697]],

        [[ 0.0549,  0.8855],
         [ 0.5513,  0.3120]]])
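For 2×2 matrices the inverse has a well-known closed form. For [[a, b], [c, d]] the inverse is 1/(ad − bc) · [[d, −b], [−c, a]], which gives an independent cross-check of Example 2. The sketch below computes this for the whole batch with plain tensor arithmetic; the names a, b, c, d, det, adj and manual are illustrative only. It also confirms that multiplying by a (batch_size, 1, 1) tensor of ones leaves the inverses unchanged. The seed and float64 dtype are assumptions chosen to make the check repeatable.

```python
import torch

torch.manual_seed(0)
batch = torch.randn(3, 2, 2, dtype=torch.float64)

# Closed form for [[a, b], [c, d]]: 1 / (a*d - b*c) * [[d, -b], [-c, a]]
a, b = batch[:, 0, 0], batch[:, 0, 1]
c, d = batch[:, 1, 0], batch[:, 1, 1]
det = a * d - b * c  # shape (3,): one determinant per sample

# Build the adjugate matrices row by row, giving shape (3, 2, 2)
adj = torch.stack([torch.stack([d, -b], dim=-1),
                   torch.stack([-c, a], dim=-1)], dim=-2)
manual = adj / det[:, None, None]  # broadcast each determinant over its matrix

print(torch.allclose(manual, torch.inverse(batch)))  # True

# Multiplying by ones of shape (batch_size, 1, 1) broadcasts but changes nothing
ones = torch.ones(3, 1, 1, dtype=torch.float64)
print(torch.equal(torch.inverse(batch) * ones, torch.inverse(batch)))  # True
```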

