How To Sort The Elements of a Tensor in PyTorch?
In this article, we are going to see how to sort the elements of a PyTorch Tensor in Python.
To sort the elements of a PyTorch tensor, we use the torch.sort() method. When the tensor is 2-dimensional, we can sort its elements along either the columns or the rows.
Syntax: torch.sort(input, dim=-1, descending=False)
- input: It is an input PyTorch tensor.
- dim: The dimension along which the tensor is sorted. It is an optional int value. Default is -1, the last dimension.
- descending: An optional boolean value that selects ascending or descending order. Default is False, which sorts in ascending order.
Returns: It returns a named tuple of (values, indices), where values holds the sorted elements and indices holds the position each sorted element had in the original input tensor along dim.
In this example, we sort the elements of a 1-dimensional tensor in ascending and descending order. We apply the torch.sort() method to the input tensor. To sort in descending order, pass descending=True to the method.
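The 1-dimensional case can be sketched as follows. The tensor values and the variable names (t, asc, desc) are illustrative choices made for this sketch, not taken from the original article.

```python
import torch

# Illustrative 1-D tensor (values chosen for this sketch)
t = torch.tensor([3.0, 1.0, 4.0, 1.5, 5.0])

# Ascending order is the default (descending=False)
asc = torch.sort(t)

# Pass descending=True to reverse the order
desc = torch.sort(t, descending=True)

# The result is a named tuple with .values and .indices
print("Ascending values: ", asc.values)
print("Ascending indices:", asc.indices)
print("Descending values: ", desc.values)
print("Descending indices:", desc.indices)

# Indexing the input with the returned indices reproduces the sorted values
print(torch.equal(t[asc.indices], asc.values))
```

The result can also be unpacked directly, as in `values, indices = torch.sort(t)`.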
In this example, we sort the elements of a 2-dimensional tensor in ascending as well as descending order along the columns.
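A minimal sketch of the column case, assuming that "along the columns" means each column is sorted independently, which corresponds to dim=0. The matrix values and the names (m, col_asc, col_desc) are illustrative.

```python
import torch

# Illustrative 3x3 tensor (values chosen for this sketch)
m = torch.tensor([[7, 2, 9],
                  [4, 8, 1],
                  [5, 3, 6]])

# dim=0 sorts down each column: every column is ordered independently
col_asc = torch.sort(m, dim=0)
col_desc = torch.sort(m, dim=0, descending=True)

print("Ascending along columns:\n", col_asc.values)
print("Indices:\n", col_asc.indices)
print("Descending along columns:\n", col_desc.values)
print("Indices:\n", col_desc.indices)
```

Here the indices are row positions within each column, so the first column of col_asc.indices shows which original rows supplied the smallest, middle, and largest entries of that column.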
In this example, we sort the elements of a 2-dimensional tensor in ascending as well as descending order along the rows.
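A minimal sketch of the row case, assuming that "along the rows" means each row is sorted independently, which corresponds to dim=1 (the same as the default dim=-1 for a 2-dimensional tensor). The matrix values and the names (m2, row_asc, row_desc) are illustrative.

```python
import torch

# Illustrative 3x3 tensor (values chosen for this sketch)
m2 = torch.tensor([[7, 2, 9],
                   [4, 8, 1],
                   [5, 3, 6]])

# dim=1 sorts across each row: every row is ordered independently
row_asc = torch.sort(m2, dim=1)
row_desc = torch.sort(m2, dim=1, descending=True)

print("Ascending along rows:\n", row_asc.values)
print("Indices:\n", row_asc.indices)
print("Descending along rows:\n", row_desc.values)
print("Indices:\n", row_desc.indices)

# torch.gather with the returned indices rebuilds the sorted tensor
print(torch.equal(torch.gather(m2, 1, row_asc.indices), row_asc.values))
```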