
Tensor.detach() Method in Python PyTorch

Last Updated : 10 Jun, 2022

In this article, we will look at the Tensor.detach() method in PyTorch using Python.

PyTorch is an open-source deep learning framework with Python and C++ interfaces, and its functionality is exposed through the torch module. In PyTorch, input data is represented as tensors, and the framework includes an automatic differentiation engine (autograd) that computes gradients for backpropagation. The Tensor.detach() method separates a tensor from the computational graph: it returns a new tensor that shares the same underlying data as the original but does not require a gradient, so no gradient flows back through it. Note that detach() does not move a tensor between devices. To move a tensor from the Graphics Processing Unit (GPU) to the Central Processing Unit (CPU), use .cpu(); the two are often combined as tensor.detach().cpu(), for example before converting the result to a NumPy array. detach() takes no parameters and returns the detached tensor.
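The short sketch below separates the two steps discussed above: detach() removes the tensor from the graph but keeps it on the same device, while .cpu() is what actually moves it. It assumes nothing beyond a standard PyTorch install and falls back to the CPU when no GPU is available, so on a CPU-only machine both tensors simply stay on the CPU.

```python
import torch

# Use a GPU if one is available; otherwise everything stays on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.tensor([1.0, 2.0, 3.0], device=device, requires_grad=True)

# detach() only removes the tensor from the graph; its device does not change.
detached = x.detach()
print(detached.device == x.device)  # True
print(detached.requires_grad)       # False

# Moving to the CPU is a separate step done with .cpu().
on_cpu = x.detach().cpu()
print(on_cpu.device)  # cpu
print(on_cpu.tolist())  # [1.0, 2.0, 3.0]
```

The common pattern tensor.detach().cpu().numpy() follows the same order: first leave the graph, then move to the CPU, then convert.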

Syntax: tensor.detach()

Return: the detached tensor (a new tensor that shares data with the original and has requires_grad=False)
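Because the returned tensor shares memory with the original, an in-place change to one is visible in the other. The sketch below illustrates this and shows detach().clone() as one way to get an independent copy; the variable names are only illustrative.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
d = x.detach()

# Same underlying memory: the data pointers match.
print(d.data_ptr() == x.data_ptr())  # True

# An in-place change through the detached tensor is also seen by x.
d[0] = 10.0
print(x[0].item())  # 10.0

# detach().clone() gives an independent copy with its own memory,
# so editing the copy leaves x untouched.
c = x.detach().clone()
c[1] = 20.0
print(x[1].item())  # 2.0
```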

Example 1:

In this example, we will create a one-dimensional tensor with the requires_grad argument set to True and detach it using the tensor.detach() method.

Python3




# import the  torch module
import torch
  
# create a one-dimensional tensor with 5 elements and
# the requires_grad parameter set to True
tensor1 = torch.tensor([7.8, 3.2, 4.4, 4.3, 3.3], requires_grad=True)
print(tensor1)
  
# detach the tensor
print(tensor1.detach())


Output:

tensor([7.8000, 3.2000, 4.4000, 4.3000, 3.3000], requires_grad=True)

tensor([7.8000, 3.2000, 4.4000, 4.3000, 3.3000])
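The detached output above no longer shows requires_grad=True. In practice, this means a detached tensor acts as a constant during backpropagation. The sketch below uses two made-up branches, one through x and one through x.detach(), to show that only the attached branch contributes to the gradient.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# y depends on x through the graph; z uses a detached copy,
# so it is treated as a constant during backpropagation.
y = x * 2
z = x.detach() * 3

loss = (y + z).sum()
loss.backward()

# Only the y branch contributes: d(loss)/dx = 2 for every element.
print(x.grad)  # tensor([2., 2., 2.])
```

Without the detach() call, each element of x.grad would instead be 2 + 3 = 5.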

Example 2:

In this example, we will create a two-dimensional tensor with requires_grad set to False and detach it using the tensor.detach() method. Because this tensor was not tracking gradients in the first place, the output shows that detaching it leaves the printed tensor unchanged.

Python3




# import the  torch module
import torch
  
# create a two-dimensional tensor (2 rows x 5 columns) with
# the requires_grad parameter set to False
tensor1 = torch.tensor([[7.8, 3.2, 4.4, 4.3, 3.3],
                        [3., 6., 7., 3., 2.]], requires_grad=False)
print(tensor1)
  
# detach the tensor
print(tensor1.detach())


Output:

tensor([[7.8000, 3.2000, 4.4000, 4.3000, 3.3000],
        [3.0000, 6.0000, 7.0000, 3.0000, 2.0000]])
        
tensor([[7.8000, 3.2000, 4.4000, 4.3000, 3.3000],
        [3.0000, 6.0000, 7.0000, 3.0000, 2.0000]])
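To compare the two examples directly, the sketch below detaches three kinds of tensors: a leaf that requires grad (as in Example 1), a leaf that does not (as in Example 2), and an intermediate result produced by an operation. The names a, b and c are only illustrative. In every case the detached tensor has requires_grad=False and no grad_fn.

```python
import torch

a = torch.tensor([1.0, 2.0], requires_grad=True)   # leaf that tracks gradients
b = torch.tensor([1.0, 2.0], requires_grad=False)  # leaf that does not
c = a * 2                                          # non-leaf with a grad_fn

results = {}
for name, t in [("a", a), ("b", b), ("c", c)]:
    d = t.detach()
    # Store (original requires_grad, detached requires_grad, detached has no grad_fn).
    results[name] = (t.requires_grad, d.requires_grad, d.grad_fn is None)
    print(name, results[name])

# a (True, False, True)
# b (False, False, True)
# c (True, False, True)
```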

