How to find the k-th and the top “k” elements of a tensor in PyTorch?
In this article, we will see how to find the k-th element and the top 'k' elements of a tensor.
We can find the k-th smallest element of a tensor with torch.kthvalue(), and the top 'k' elements of a tensor with torch.topk().
- torch.kthvalue() function: This function returns the k-th smallest element of the tensor along a given dimension, that is, the element that would be at position k if the tensor were sorted in ascending order. It also returns that element's index in the original tensor.
Syntax: torch.kthvalue(input, k, dim=None, keepdim=False, *, out=None)
- input: the input tensor.
- k: an integer; the k in "k-th smallest element" (k=1 gives the minimum).
- dim: the dimension along which to find the k-th value. If it is not given, the last dimension of the input is used.
- keepdim (bool): whether the output tensors keep dim (with size 1) or have it removed.
Return: This method returns a namedtuple (values, indices), where values holds the k-th smallest elements and indices holds their positions in the original tensor along dim.
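To see how dim and keepdim behave, here is a minimal sketch on a small 2-D tensor. The tensor values are arbitrary and chosen only for illustration; the call itself uses the torch.kthvalue() parameters listed above.

```python
import torch

# Illustrative 2-D tensor (values chosen arbitrarily)
x = torch.tensor([[3.0, 1.0, 2.0],
                  [9.0, 7.0, 8.0]])

# 2nd smallest element of each row; dim=1 is also the default (last) dimension
values, indices = torch.kthvalue(x, 2, dim=1)
print(values.tolist())   # [2.0, 8.0]
print(indices.tolist())  # [2, 2]  (positions in the original rows)

# keepdim=True keeps the reduced dimension with size 1
values_kd, indices_kd = torch.kthvalue(x, 2, dim=1, keepdim=True)
print(values_kd.shape)   # torch.Size([2, 1])
```

Keeping the reduced dimension is handy when the result must broadcast back against x, for example when comparing each row with its own k-th value.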
- torch.topk() function: This function finds the top 'k' elements of a given tensor. It returns the k largest (or, optionally, smallest) elements along with their indices in the original tensor.
Syntax: torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
- input: the input tensor.
- k: an integer; the k in "top-k".
- dim: the dimension along which to select the elements. If it is not given, the last dimension of the input is used.
- largest (bool): whether to return the largest elements (True, the default) or the smallest ones (False).
- sorted (bool): whether to return the selected elements in sorted order.
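Return: This function returns a namedtuple (values, indices) holding the k largest elements (or the k smallest, if largest=False) along the given dimension and their indices in the original tensor.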
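A minimal sketch of the dim and largest parameters on a small 2-D tensor follows. The values are illustrative only. With sorted=True (the default), the largest elements come back in descending order and, when largest=False, the smallest come back in ascending order.

```python
import torch

# Illustrative 2-D tensor (values chosen arbitrarily)
x = torch.tensor([[5.0, 1.0, 4.0, 2.0],
                  [0.5, 3.0, 6.0, 7.0]])

# 2 largest elements of each row (largest=True, sorted=True by default)
top_vals, top_idx = torch.topk(x, 2, dim=1)
print(top_vals.tolist())  # [[5.0, 4.0], [7.0, 6.0]]
print(top_idx.tolist())   # [[0, 2], [3, 2]]

# largest=False selects the 2 smallest elements of each row instead
low_vals, low_idx = torch.topk(x, 2, dim=1, largest=False)
print(low_vals.tolist())  # [[1.0, 2.0], [0.5, 3.0]]
print(low_idx.tolist())   # [[1, 3], [0, 1]]
```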
Example 1: The following program finds the k-th element of a tensor.
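A minimal sketch of such a program, using an arbitrary 1-D tensor chosen for illustration, could look like this. For a 1-D tensor the default dim is the only dimension, so torch.kthvalue() returns 0-dimensional tensors, which .item() converts to Python numbers.

```python
import torch

# Illustrative 1-D tensor
a = torch.tensor([6.0, 2.0, 9.0, 4.0, 1.0])

# 3rd smallest element: sorted ascending the tensor is [1, 2, 4, 6, 9], so the value is 4
value, index = torch.kthvalue(a, 3)
print(value.item(), index.item())  # 4.0 3  (4.0 sits at index 3 in the original tensor)
```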
Example 2: The following program finds the top k elements of a tensor.
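A minimal sketch of such a program, again using an arbitrary 1-D tensor chosen for illustration, is shown below. With the default largest=True and sorted=True, the values come back in descending order.

```python
import torch

# Illustrative 1-D tensor
a = torch.tensor([6.0, 2.0, 9.0, 4.0, 1.0])

# Top 3 largest elements and their indices in the original tensor
values, indices = torch.topk(a, 3)
print(values.tolist())   # [9.0, 6.0, 4.0]
print(indices.tolist())  # [2, 0, 3]
```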