Python PyTorch stack() method

Last Updated : 28 Feb, 2022

PyTorch's torch.stack() method joins (concatenates) a sequence of two or more tensors along a new dimension: it inserts a new dimension and concatenates the tensors along it. All input tensors must have the same shape. We could also use torch.cat() to join tensors along an existing dimension, but here we discuss the torch.stack() method.

Syntax: torch.stack(tensors, dim=0)

Arguments:

  • tensors: a sequence of tensors, all of the same shape
  • dim: the dimension to insert. It must be an integer between 0 and the number of dimensions of the input tensors (inclusive).

Returns: a single tensor containing the inputs stacked along the new dimension.
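To see how torch.stack() differs from torch.cat(), here is a minimal sketch (not part of the original examples) comparing the resulting shapes for two small 1-D tensors:

```python
import torch

a = torch.tensor([1., 2.])
b = torch.tensor([3., 4.])

# stack inserts a new dimension: two shape-(2,) tensors -> shape (2, 2)
print(torch.stack((a, b)).shape)   # torch.Size([2, 2])

# cat joins along an existing dimension: two shape-(2,) tensors -> shape (4,)
print(torch.cat((a, b)).shape)     # torch.Size([4])
```

In short, stack adds a dimension while cat does not.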

Let’s understand the torch.stack() method with the help of some Python 3 examples.

Example 1: 

In the Python example below we join two one-dimensional tensors using the torch.stack() method.

Python3
# Python 3 program to demonstrate the torch.stack() method
# on two one-dimensional tensors

# importing torch
import torch

# creating tensors
x = torch.tensor([1., 3., 6., 10.])
y = torch.tensor([2., 7., 9., 13.])

# printing above created tensors
print("Tensor x:", x)
print("Tensor y:", y)

# join above tensors using "torch.stack()"
print("join tensors:")
t = torch.stack((x, y))

# print final tensor after join
print(t)

print("join tensors dimension 0:")
t = torch.stack((x, y), dim=0)
print(t)

print("join tensors dimension 1:")
t = torch.stack((x, y), dim=1)
print(t)


Output:

Tensor x: tensor([ 1.,  3.,  6., 10.])
Tensor y: tensor([ 2.,  7.,  9., 13.])
join tensors:
tensor([[ 1.,  3.,  6., 10.],
        [ 2.,  7.,  9., 13.]])
join tensors dimension 0:
tensor([[ 1.,  3.,  6., 10.],
        [ 2.,  7.,  9., 13.]])
join tensors dimension 1:
tensor([[ 1.,  2.],
        [ 3.,  7.],
        [ 6.,  9.],
        [10., 13.]])

Explanation: In the above code, tensors x and y are one-dimensional, each with four elements. The stacked result is a 2D tensor. Since each input has one dimension, the valid values of dim are 0 and 1. With dim=0 each input tensor becomes a row of the result; with dim=1 each input becomes a column, so the result looks like the dim=0 output transposed.
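The transpose relationship described above can be checked directly; this short sketch (an addition to the article's examples) stacks the same two tensors both ways and compares:

```python
import torch

x = torch.tensor([1., 3., 6., 10.])
y = torch.tensor([2., 7., 9., 13.])

rows = torch.stack((x, y), dim=0)   # shape (2, 4): inputs become rows
cols = torch.stack((x, y), dim=1)   # shape (4, 2): inputs become columns

# for 1-D inputs, the dim=1 result equals the transpose of the dim=0 result
print(torch.equal(cols, rows.t()))  # True
```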

Example 2:

In the Python example below we join two two-dimensional (2D) tensors using the torch.stack() method.

Python3
# Python 3 program to demonstrate torch.stack() method
# for two 2D tensors.
# importing torch
import torch
  
# creating tensors
x = torch.tensor([[1., 3., 6.], [10., 13., 20.]])
y = torch.tensor([[2., 7., 9.], [14., 21., 34.]])
  
# printing above created tensors
print("Tensor x:\n", x)
print("Tensor y:\n", y)
  
# join above tensor using "torch.stack()"
print("join tensors")
t = torch.stack((x, y))
  
# print final tensor after join
print(t)
  
print("join tensors in dimension 0:")
t = torch.stack((x, y), 0)
print(t)
  
print("join tensors in dimension 1:")
t = torch.stack((x, y), 1)
print(t)
print("join tensors in dimension 2:")
t = torch.stack((x, y), 2)
print(t)


Output:

Tensor x:
 tensor([[ 1.,  3.,  6.],
        [10., 13., 20.]])
Tensor y:
 tensor([[ 2.,  7.,  9.],
        [14., 21., 34.]])
join tensors
tensor([[[ 1.,  3.,  6.],
         [10., 13., 20.]],

        [[ 2.,  7.,  9.],
         [14., 21., 34.]]])
join tensors in dimension 0:
tensor([[[ 1.,  3.,  6.],
         [10., 13., 20.]],

        [[ 2.,  7.,  9.],
         [14., 21., 34.]]])
join tensors in dimension 1:
tensor([[[ 1.,  3.,  6.],
         [ 2.,  7.,  9.]],

        [[10., 13., 20.],
         [14., 21., 34.]]])
join tensors in dimension 2:
tensor([[[ 1.,  2.],
         [ 3.,  7.],
         [ 6.,  9.]],

        [[10., 14.],
         [13., 21.],
         [20., 34.]]])

Explanation: In the above code, x and y are two-dimensional tensors, and the stacked result is a 3D tensor. Since each input has two dimensions, the valid values of dim are 0, 1, and 2. Compare the output tensors for dim = 0, 1, and 2 to see where the new dimension is inserted.

Example 3:

In this example, we join more than two tensors; torch.stack() accepts any number of tensors, as long as they all have the same shape.

Python3
# Python 3 program to demonstrate torch.stack() method
# for three one-dimensional tensors
# importing torch
import torch
  
# creating tensors
x = torch.tensor([1., 3., 6., 10.])
y = torch.tensor([2., 7., 9., 13.])
z = torch.tensor([4., 5., 8., 11.])
  
# printing above created tensors
print("Tensor x:", x)
print("Tensor y:", y)
print("Tensor z:", z)
  
# join above tensor using "torch.stack()"
print("join tensors:")
t = torch.stack((x, y, z))
# print final tensor after join
print(t)
  
print("join tensors dimension 0:")
t = torch.stack((x, y, z), dim=0)
print(t)
  
print("join tensors dimension 1:")
t = torch.stack((x, y, z), dim=1)
print(t)


Output:

Tensor x: tensor([ 1.,  3.,  6., 10.])
Tensor y: tensor([ 2.,  7.,  9., 13.])
Tensor z: tensor([ 4.,  5.,  8., 11.])
join tensors:
tensor([[ 1.,  3.,  6., 10.],
        [ 2.,  7.,  9., 13.],
        [ 4.,  5.,  8., 11.]])
join tensors dimension 0:
tensor([[ 1.,  3.,  6., 10.],
        [ 2.,  7.,  9., 13.],
        [ 4.,  5.,  8., 11.]])
join tensors dimension 1:
tensor([[ 1.,  2.,  4.],
        [ 3.,  7.,  5.],
        [ 6.,  9.,  8.],
        [10., 13., 11.]])
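Because torch.stack() takes any number of same-shaped tensors, a common practical use is collecting per-sample tensors in a Python list and stacking them into a single batch tensor. A minimal sketch (the variable names here are illustrative, not from the article):

```python
import torch

# a list of equally shaped tensors, e.g. collected in a loop
samples = [torch.full((3,), float(i)) for i in range(5)]

# stack them into one batch tensor of shape (5, 3)
batch = torch.stack(samples)
print(batch.shape)  # torch.Size([5, 3])
print(batch[2])     # tensor([2., 2., 2.])
```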

Example 4: Demonstrating Errors

In the example below we show the error raised when the input tensors do not have the same shape.

Python3
# Python 3 program to demonstrate the error raised by
# torch.stack() when the input tensors differ in shape

# importing torch
import torch

# creating tensors of different shapes
x = torch.tensor([1., 3., 6., 10.])
y = torch.tensor([2., 7., 9.])

# printing the shapes of above created tensors
print("Shape of x:", x.shape)
print("Shape of y:", y.shape)

# trying to join the tensors raises a RuntimeError
# because the shapes do not match
t = torch.stack((x, y))
print(t)


Output:

Shape of x: torch.Size([4])

Shape of y: torch.Size([3])

RuntimeError: stack expects each tensor to be equal size, but got [4] at entry 0 and [3] at entry 1

Notice that the shapes of the two tensors are not the same, so torch.stack() raises a RuntimeError. In the same way, a RuntimeError is raised when the tensors do not have the same number of dimensions. Try it yourself with tensors of different dimensions and inspect the error message.
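If you would rather avoid the exception, one option is to check the shapes up front before calling torch.stack(). A minimal sketch of this defensive pattern (an addition to the article, not a required step):

```python
import torch

x = torch.tensor([1., 3., 6., 10.])
y = torch.tensor([2., 7., 9.])

tensors = (x, y)
# verify all shapes match before stacking, instead of catching RuntimeError
if all(t.shape == tensors[0].shape for t in tensors):
    print(torch.stack(tensors))
else:
    print("cannot stack, shapes differ:", [tuple(t.shape) for t in tensors])
```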
