Tensor Operations in PyTorch

In this article, we will discuss tensor operations in PyTorch.

PyTorch is a Python library for scientific computing and deep learning whose core data structure is the tensor. A tensor is a multi-dimensional array of data, similar to a NumPy array. We can create a tensor using the torch.tensor() function:

Syntax: torch.tensor([[element1, element2, ..., elementN], ..., [element1, element2, ..., elementN]])

where,

  • torch is the module
  • tensor is the function
  • elements are the data
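As a quick illustration of the syntax above, the following sketch creates a 2x3 tensor and inspects its shape and element type. The variable names are arbitrary, and the inferred dtype reflects PyTorch's default of storing Python integers as int64.

```python
import torch

# A 2x3 tensor built from a nested Python list (2 rows, 3 elements each)
t = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

print(t.shape)  # torch.Size([2, 3])
print(t.ndim)   # 2
print(t.dtype)  # torch.int64, inferred from the Python integers

# The element type can also be given explicitly
f = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
print(f.dtype)  # torch.float32
```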

The tensor operations covered in this article are expand(), permute(), tolist(), narrow(), and where().

expand()

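This operation returns a new view of the tensor with its dimensions expanded to larger sizes. A dimension of size 1 can be expanded to any size, and new dimensions can be added at the front, but a dimension whose size is greater than 1 must keep its size. No data is copied: the same elements are repeated along the expanded dimensions.

Syntax: tensor.expand(n, r, c)

where,

  • tensor is the input tensor
  • n is the size of the new leading dimension, i.e. the number of copies of the tensor to return
  • r is the number of rows in each copy; it must match the tensor's current number of rows unless that number is 1 (passing -1 keeps it unchanged)
  • c is the number of columns in each copy; it must match the tensor's current number of columns unless that number is 1 (passing -1 keeps it unchanged)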

Example: In this example, we will expand a 2x3 tensor into 4 copies, each with 2 rows and 3 columns.

# import module
import torch
  
# create a 2x3 tensor:
# 2 rows with 3 elements each
data = torch.tensor([[10, 20, 30], 
                     [45, 67, 89]])
  
# display
print(data)
  
# expand into 4 copies of the tensor,
# each with 2 rows and 3 columns
print(data.expand(4, 2, 3))

Output:

tensor([[10, 20, 30],
        [45, 67, 89]])
tensor([[[10, 20, 30],
         [45, 67, 89]],

        [[10, 20, 30],
         [45, 67, 89]],

        [[10, 20, 30],
         [45, 67, 89]],

        [[10, 20, 30],
         [45, 67, 89]]])
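The example above adds a new leading dimension. The sketch below, using an arbitrary 3x1 tensor, shows the other common use: stretching a size-1 dimension. It also checks that the result is a view sharing the original data (the expanded dimension has stride 0) and that a non-singleton dimension cannot be resized. Calling clone() at the end is one way to get an independent copy; it is shown as an option, not a requirement.

```python
import torch

# A 3x1 column tensor: dimension 1 has size 1, so it can be expanded
col = torch.tensor([[1], [2], [3]])

# Expand dimension 1 to size 4; -1 keeps dimension 0 at its current size
wide = col.expand(-1, 4)
print(wide.tolist())  # [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]

# expand() returns a view: the data is shared and dimension 1 has stride 0
print(wide.data_ptr() == col.data_ptr())  # True
print(wide.stride())                      # (1, 0)

# A non-singleton dimension cannot be expanded to a different size
try:
    col.expand(5, 4)
except RuntimeError as err:
    print("RuntimeError:", err)

# Materialize an independent copy if the expanded data must be modified
copy = wide.clone()
```

Because expand() only manipulates strides, it is cheap even for large target sizes; tensor.repeat() is the alternative when real copies of the data are wanted.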

permute()

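This operation returns a view of the tensor with its dimensions reordered. No data is copied; only the order in which the dimensions are indexed changes.

Syntax: tensor.permute(a, b, c)

where,

  • tensor is the input tensor
  • a, b, c is a reordering of the dimension indices 0, 1, 2: dimension 0 of the result is dimension a of the input, dimension 1 is dimension b, and dimension 2 is dimension c
  • for example, on a tensor of shape (1, 2, 3), permute(1, 2, 0) gives shape (2, 3, 1), and permute(2, 1, 0) reverses the dimension order, giving shape (3, 2, 1)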

Example: In this example, we apply two different dimension orders to a tensor of shape (1, 2, 3).

# import module
import torch
  
# create a tensor of shape (1, 2, 3):
# 2 rows with 3 elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89]]])
  
# display
print(data)
  
# permute(1, 2, 0): shape (1, 2, 3) becomes (2, 3, 1)
print(data.permute(1, 2, 0))
  
# permute(2, 1, 0): reverse the dimension order, giving shape (3, 2, 1)
print(data.permute(2, 1, 0))

Output:

tensor([[[10, 20, 30],
         [45, 67, 89]]])
tensor([[[10],
         [20],
         [30]],

        [[45],
         [67],
         [89]]])
tensor([[[10],
         [45]],

        [[20],
         [67]],

        [[30],
         [89]]])
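A common practical use of permute() is changing an image tensor's layout. The sketch below assumes a hypothetical batch stored as (batch, height, width, channels) and moves it to the (batch, channels, height, width) layout that torch.nn.Conv2d expects. It also shows that the permuted view is generally not contiguous, so a contiguous copy is made before calling view().

```python
import torch

# Hypothetical image batch stored as (batch, height, width, channels)
images = torch.arange(8 * 32 * 32 * 3).reshape(8, 32, 32, 3)

# Reorder to (batch, channels, height, width)
nchw = images.permute(0, 3, 1, 2)
print(nchw.shape)  # torch.Size([8, 3, 32, 32])

# Element [b, c, h, w] of the result is element [b, h, w, c] of the input
print((nchw[1, 2, 5, 7] == images[1, 5, 7, 2]).item())  # True

# permute() returns a view with rearranged strides, so it is usually not contiguous
print(nchw.is_contiguous())  # False

# .view() needs compatible strides; make a contiguous copy first (or use .reshape())
flat = nchw.contiguous().view(8, -1)
print(flat.shape)  # torch.Size([8, 3072])
```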

tolist()

This method returns the tensor's data as a Python list, nested to the same depth as the tensor's number of dimensions, with each element converted to a Python number.

Syntax: tensor.tolist()

Example: In this example, we convert a tensor of shape (1, 2, 3) into a nested list.

# import module
import torch
  
# create a tensor of shape (1, 2, 3):
# 2 rows with 3 elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89]]])
  
# display
print(data)
  
# convert the tensor to list
print(data.tolist())

Output:

tensor([[[10, 20, 30],
         [45, 67, 89]]])
[[[10, 20, 30], [45, 67, 89]]]
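The sketch below, using an arbitrary float tensor, shows that tolist() yields plain Python numbers, that a 0-dimensional tensor becomes a scalar rather than a list, and, for comparison, the related item() method for single-element tensors.

```python
import torch

t = torch.tensor([[1.5, 2.0],
                  [3.25, 4.0]])

# Elements become ordinary Python floats inside nested lists
nested = t.tolist()
print(nested)                # [[1.5, 2.0], [3.25, 4.0]]
print(type(nested[0][0]))    # <class 'float'>

# A 0-dimensional tensor converts to a plain Python scalar, not a list
scalar = torch.tensor(7).tolist()
print(scalar, type(scalar))  # 7 <class 'int'>

# For a single-element tensor, item() is the usual way to get the scalar
print(t.sum().item())        # 10.75
```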

narrow()

This function returns a narrowed version of the tensor: a view that keeps only a contiguous range of indices along one dimension, similar to slicing. No data is copied, so the result shares memory with the input.

Syntax: torch.narrow(tensor, d, i, l)

where,

  • tensor is the input tensor
  • d is the dimension to narrow
  • i is the starting index along dimension d
  • l is the number of elements to keep along dimension d, so indices i to i + l - 1 are kept

Example: In this example, we narrow a 3x3 tensor along dimension 1 (the columns) twice: first starting at index 1 with length 2, then starting at index 0 with length 2.

# import module
import torch
  
# create a 3x3 tensor:
# 3 rows with 3 elements each
data = torch.tensor([[10, 20, 30], 
                     [45, 67, 89], 
                     [23, 45, 67]])
  
# display
print(data)
  
# narrow along dimension 1 (columns),
# starting at index 1 and keeping 2 columns
print(torch.narrow(data, 1, 1, 2))
  
# narrow along dimension 1 (columns),
# starting at index 0 and keeping 2 columns
print(torch.narrow(data, 1, 0, 2))

Output:

tensor([[10, 20, 30],
        [45, 67, 89],
        [23, 45, 67]])
tensor([[20, 30],
        [67, 89],
        [45, 67]])
tensor([[10, 20],
        [45, 67],
        [23, 45]])
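To make the relationship with slicing concrete, the sketch below reuses the 3x3 tensor from the example, narrows it along dimension 0 using the tensor.narrow() method form, and checks the result against data[1:3]. It then shows that writing to the narrowed view changes the original tensor, while a clone() does not.

```python
import torch

data = torch.tensor([[10, 20, 30],
                     [45, 67, 89],
                     [23, 45, 67]])

# narrow(dim, start, length) selects the same elements as basic slicing
rows = data.narrow(0, 1, 2)          # rows 1 and 2
print(torch.equal(rows, data[1:3]))  # True

# The result is a view: writing to it modifies the original tensor
rows[0, 0] = -1
print(data[1, 0].item())             # -1

# Use clone() when an independent copy is needed
copy = data.narrow(1, 0, 1).clone()  # first column, copied
copy.fill_(0)
print(data[:, 0].tolist())           # [10, -1, 23]
```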

where()

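This function returns a new tensor whose elements are selected from one of two sources, depending on whether a condition holds at each position.

Syntax: torch.where(condition, value1, value2)

where,

  • condition is a boolean tensor, typically built by applying a comparison operator to an existing tensor
  • value1 gives the output element at every position where condition is True
  • value2 gives the output element at every position where condition is False
  • value1 and value2 can be tensors or scalars; they are broadcast to a common shape with condition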

Example: In this example, we build conditions with different comparison operators (>, <, and ==).

# import module
import torch
  
# create a tensor of shape (1, 3, 3):
# 3 rows with 3 elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89],
                      [23, 45, 67]]])
  
# display
print(data)
  
# 100 where the element is greater
# than 45, otherwise 50
print(torch.where(data > 45, 100, 50))
  
# 100 where the element is less
# than 45, otherwise 50
print(torch.where(data < 45, 100, 50))
  
# 100 where the element is equal
# to 23, otherwise 50
print(torch.where(data == 23, 100, 50))

Output:

tensor([[[10, 20, 30],
         [45, 67, 89],
         [23, 45, 67]]])
tensor([[[ 50,  50,  50],
         [ 50, 100, 100],
         [ 50,  50, 100]]])
tensor([[[100, 100, 100],
         [ 50,  50,  50],
         [100,  50,  50]]])
tensor([[[ 50,  50,  50],
         [ 50,  50,  50],
         [100,  50,  50]]])
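The examples above use scalars for both values. The sketch below, with an arbitrary 2x2 float tensor, shows that the values can also be tensors: elements are picked position by position, and a 1-D tensor is broadcast across rows. It also shows the single-argument form, torch.where(condition), which returns the indices where the condition is True (the same as torch.nonzero(condition, as_tuple=True)).

```python
import torch

x = torch.tensor([[-2.0, 3.0],
                  [4.0, -1.0]])

# Keep positive values and replace the rest with 0 (a ReLU-style clamp)
relu = torch.where(x > 0, x, torch.zeros_like(x))
print(relu.tolist())  # [[0.0, 3.0], [4.0, 0.0]]

# Select elementwise between two tensors; the 1-D fallback is broadcast to every row
fallback = torch.tensor([10.0, 20.0])
mixed = torch.where(x > 0, x, fallback)
print(mixed.tolist())  # [[10.0, 3.0], [4.0, 20.0]]

# With only a condition, torch.where returns a tuple of index tensors,
# one per dimension, giving the positions where the condition is True
row_idx, col_idx = torch.where(x > 0)
print(row_idx.tolist(), col_idx.tolist())  # [0, 1] [1, 0]
```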

Last Updated : 04 Jan, 2022