
Tensor Operations in PyTorch

In this article, we will discuss tensor operations in PyTorch.

PyTorch is an open-source Python library for tensor computation and deep learning. A tensor is a multi-dimensional array whose elements all share one data type, similar to a NumPy array. We can create a tensor using the torch.tensor() function:

Syntax: torch.tensor([[[element1, element2, ..., element n], ..., [element1, element2, ..., element n]]])

where,

  • torch is the module
  • tensor is the function
  • elements are the data
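
As a minimal sketch of the syntax above, the snippet below builds a small tensor from a nested Python list and inspects its shape and data type. The name matrix is only illustrative.

```python
import torch

# Build a 2-D tensor from a nested Python list: 2 rows, 3 columns
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6]])

print(matrix.shape)  # torch.Size([2, 3])
print(matrix.dtype)  # torch.int64, inferred from the Python integers
print(matrix.ndim)   # 2
```

The nesting depth of the list determines the number of dimensions, and the element type determines the default dtype.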

The tensor operations covered in this article are:

expand()

This operation returns a new view of the tensor in which dimensions of size 1 are expanded to a larger size. It can also add new leading dimensions, which repeats the whole tensor along them. No data is copied.

Syntax: tensor.expand(n,r,c)

where,

  • tensor is the input tensor
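  • n is the size of the new leading dimension, that is, the number of copies of the tensor
  • r is the number of rows in each copy; it must equal the input's row count unless the input has only 1 row
  • c is the number of columns in each copy; it must equal the input's column count unless the input has only 1 column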

Example: In this example, we will expand a tensor with 2 rows and 3 columns into 4 copies, which gives a tensor of shape (4, 2, 3).

Python3




# import module
import torch
  
# create a tensor with 2 rows
# of 3 elements each
data = torch.tensor([[10, 20, 30], 
                     [45, 67, 89]])
  
# display
print(data)
  
# expand the tensor into 4 copies,
# each with 2 rows and 3 columns
print(data.expand(4, 2, 3))


Output:

tensor([[10, 20, 30],
        [45, 67, 89]])
tensor([[[10, 20, 30],
         [45, 67, 89]],

        [[10, 20, 30],
         [45, 67, 89]],

        [[10, 20, 30],
         [45, 67, 89]],

        [[10, 20, 30],
         [45, 67, 89]]])
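
expand() can also stretch an existing dimension of size 1, and passing -1 for a dimension leaves its size unchanged. The following is a small sketch of that behaviour; the names column and wide are illustrative and not part of the example above.

```python
import torch

# A column of shape (3, 1); only the size-1 dimension can be expanded
column = torch.tensor([[1], [2], [3]])

# -1 keeps the existing size of dimension 0; dimension 1 grows from 1 to 4
wide = column.expand(-1, 4)
print(wide.shape)  # torch.Size([3, 4])
print(wide)

# expand() returns a view, so both tensors point at the same memory
same_storage = wide.data_ptr() == column.data_ptr()
print(same_storage)  # True
```

Because the result is a view, writing into it in place is generally unsafe; calling clone() first gives an independent copy.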

permute()

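This is used to reorder the dimensions of a tensor. It returns a view of the original tensor with its dimensions rearranged.

Syntax: tensor.permute(a,b,c)

where

  • tensor is the input tensor
  • a, b, c are the indices of the original dimensions, listed in the desired new order
  • permute(1,2,0) on a 3-D tensor moves the first dimension to the end
  • permute(2,1,0) on a 3-D tensor reverses the order of the dimensions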

Example: In this example, we permute a tensor of shape (1, 2, 3) in two ways: first with permute(1, 2, 0) and then with permute(2, 1, 0).

Python3




# import module
import torch
  
# create a 3-D tensor of shape (1, 2, 3):
# one block of 2 rows with 3 elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89]]])
  
# display
print(data)
  
# move dimension 0 to the end: shape (1, 2, 3) -> (2, 3, 1)
print(data.permute(1, 2, 0))
  
# reverse the dimension order: shape (1, 2, 3) -> (3, 2, 1)
print(data.permute(2, 1, 0))


Output:

tensor([[[10, 20, 30],
         [45, 67, 89]]])
tensor([[[10],
         [20],
         [30]],

        [[45],
         [67],
         [89]]])
tensor([[[10],
         [45]],

        [[20],
         [67]],

        [[30],
         [89]]])
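
Checking the shapes makes it easier to see what permute() does to the dimensions. The sketch below repeats the same tensor and prints the shape before and after each call; the variable names are illustrative.

```python
import torch

data = torch.tensor([[[10, 20, 30],
                      [45, 67, 89]]])
print(data.shape)  # torch.Size([1, 2, 3])

# New dimension order: old dim 1, old dim 2, old dim 0
first_to_last = data.permute(1, 2, 0)
print(first_to_last.shape)  # torch.Size([2, 3, 1])

# New dimension order: old dim 2, old dim 1, old dim 0
reversed_dims = data.permute(2, 1, 0)
print(reversed_dims.shape)  # torch.Size([3, 2, 1])

# The element at data[i, j, k] is found at first_to_last[j, k, i]
print(first_to_last[1, 2, 0].item(), data[0, 1, 2].item())  # 89 89
```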

tolist()

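This method returns the tensor's data as a Python list, nested to match the tensor's dimensions. For a zero-dimensional (scalar) tensor, it returns a plain Python number instead.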

Syntax: tensor.tolist()

Example: In this example, we are going to convert the given tensor into a list.

Python3




# import module
import torch
  
# create a 3-D tensor of shape (1, 2, 3):
# one block of 2 rows with 3 elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89]]])
  
# display
print(data)
  
# convert the tensor to list
print(data.tolist())


Output:

tensor([[[10, 20, 30],
         [45, 67, 89]]])
[[[10, 20, 30], [45, 67, 89]]]
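
The values returned by tolist() are ordinary Python objects, not tensors. A short sketch, using illustrative names, shows this for a scalar tensor and for a 1-D tensor.

```python
import torch

# A zero-dimensional tensor converts to a plain Python number
scalar = torch.tensor(3.5)
value = scalar.tolist()
print(value, type(value))  # 3.5 <class 'float'>

# A 1-D tensor converts to a flat list of Python ints
vector = torch.tensor([1, 2, 3])
items = vector.tolist()
print(items, type(items[0]))  # [1, 2, 3] <class 'int'>
```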

narrow()

This function returns a narrowed version of the tensor. In other words, it keeps only a contiguous slice of the tensor along one chosen dimension and leaves the other dimensions unchanged.

Syntax: torch.narrow(tensor,d,i,l)

where,

  • tensor is the input tensor
  • d is the dimension to narrow
  • i is the starting index along dimension d
  • l is the length of the new tensor along dimension d

Example: In this example, we narrow the tensor along dimension 1 (the columns) twice: first starting from index 1 with length 2, and then starting from index 0 with length 2.

Python3




# import module
import torch
  
# create a tensor with 3 rows
# of 3 elements each
data = torch.tensor([[10, 20, 30], 
                     [45, 67, 89], 
                     [23, 45, 67]])
  
# display
print(data)
  
# narrow the tensor
# along dimension 1 (columns)
# starting from index 1
# keeping 2 columns
print(torch.narrow(data, 1, 1, 2))
  
# narrow the tensor
# along dimension 1 (columns)
# starting from index 0
# keeping 2 columns
print(torch.narrow(data, 1, 0, 2))


Output:

tensor([[10, 20, 30],
        [45, 67, 89],
        [23, 45, 67]])
tensor([[20, 30],
        [67, 89],
        [45, 67]])
tensor([[10, 20],
        [45, 67],
        [23, 45]])
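
narrow() works the same way along dimension 0 (the rows), and its result matches ordinary slicing. The sketch below checks this equivalence; the variable rows is an illustrative name.

```python
import torch

data = torch.tensor([[10, 20, 30],
                     [45, 67, 89],
                     [23, 45, 67]])

# Keep 2 rows along dimension 0, starting at row index 1
rows = torch.narrow(data, 0, 1, 2)
print(rows)

# The same selection written as a slice
print(torch.equal(rows, data[1:3]))  # True

# Narrowing columns is equivalent to slicing the second dimension
print(torch.equal(data.narrow(1, 0, 2), data[:, 0:2]))  # True
```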

where()

This function returns a new tensor whose elements are chosen, one by one, from one of two sources depending on a condition.

Syntax: torch.where(condition,statement1,statement2)

where,

  • condition is a boolean tensor, usually produced by comparing an existing tensor with a value
  • statement1 is the value (a tensor or a scalar) used where the condition is true
  • statement2 is the value (a tensor or a scalar) used where the condition is false

Example: In this example, we use different relational operators (>, < and ==) as the condition.

Python3




# import module
import torch
  
# create a 3-D tensor of shape (1, 3, 3):
# one block of 3 rows with 3 elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89],
                      [23, 45, 67]]])
  
# display
print(data)
  
# set the number 100 when the
# number is greater than 45
# otherwise 50
print(torch.where(data > 45, 100, 50))
  
# set the number 100 when the
# number is less than 45
# otherwise 50
print(torch.where(data < 45, 100, 50))
  
# set the number 100 when the number is
# equal to 23 otherwise 50
print(torch.where(data == 23, 100, 50))


Output:

tensor([[[10, 20, 30],
         [45, 67, 89],
         [23, 45, 67]]])
tensor([[[ 50,  50,  50],
         [ 50, 100, 100],
         [ 50,  50, 100]]])
tensor([[[100, 100, 100],
         [ 50,  50,  50],
         [100,  50,  50]]])
tensor([[[ 50,  50,  50],
         [ 50,  50,  50],
         [100,  50,  50]]])
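
The two value arguments of torch.where() can also be tensors, in which case each output element is taken from the matching position in one of them. A minimal sketch, with illustrative names, keeps the values above 45 and replaces the rest with 0.

```python
import torch

data = torch.tensor([[10, 20, 30],
                     [45, 67, 89]])

# Where data > 45 is true, take the element from data;
# elsewhere take the element from a tensor of zeros
kept = torch.where(data > 45, data, torch.zeros_like(data))
print(kept)
# tensor([[ 0,  0,  0],
#         [ 0, 67, 89]])
```

Using zeros_like() keeps the shape and dtype of both value tensors identical, which avoids type-promotion surprises.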


Last Updated : 04 Jan, 2022