Python | tensorflow.math.argmax() method

Last Updated: 01 Oct, 2021

TensorFlow is an open-source Python library designed by Google to develop machine learning models and deep learning neural networks. argmax() is a method in TensorFlow's math module. It returns the index of the largest value across an axis of a tensor, not the largest value itself.
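Because argmax() returns a position rather than a value, it is often paired with a reduction or an indexing op. The sketch below, assuming TensorFlow 2.x with eager execution (the default), contrasts tf.math.argmax with tf.math.reduce_max on the same tensor and uses tf.gather to read the value back at the returned index.

```python
import tensorflow as tf

# A 1-D float32 tensor; the largest element (50.0) sits at index 5
a = tf.constant([5, 10, 5.6, 7.9, 1, 50])

idx = tf.math.argmax(a)        # index of the maximum -> int64 scalar
val = tf.math.reduce_max(a)    # the maximum value itself -> float32 scalar
looked_up = tf.gather(a, idx)  # read the value stored at that index

print(idx.numpy(), val.numpy(), looked_up.numpy())
```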


Syntax: tf.math.argmax(input, axis=None, output_type=tf.dtypes.int64, name=None)

Parameters:
1. input: A tensor. Allowed dtypes for this tensor are float32,
          float64, int32, uint8, int16, int8, complex64, int64, qint8,
          quint8, qint32, bfloat16, uint16, complex128, half, uint32, uint64.
2. axis: A scalar (0-D) tensor that gives the axis to reduce across.
         Allowed dtypes are int32 and int64, and the value must lie in the range [-rank(input), rank(input)).
         It defaults to 0 when omitted; for a vector (1-D tensor), axis=0 is used.
3. output_type: The dtype of the returned indices.
                Allowed values are int32 and int64, and the default is int64.
4. name: An optional name for the operation.

Returns: A tensor of type output_type containing the indices of the maximum values along the given axis. Its shape is the input shape with that axis removed.
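The axis and output_type parameters can be exercised as follows. This is a minimal sketch assuming TensorFlow 2.x; the 2 x 3 matrix m is a made-up example, not taken from the article.

```python
import tensorflow as tf

# Hypothetical 2 x 3 matrix used only to illustrate the parameters
m = tf.constant([[1, 7, 3],
                 [8, 2, 9]])

# axis=0 reduces down each column -> one index per column, shape (3,)
col_idx = tf.math.argmax(m, axis=0)

# axis=-1 is the same as axis=1 for a rank-2 tensor -> one index per row, shape (2,)
row_idx = tf.math.argmax(m, axis=-1)

# output_type switches the result dtype from the default int64 to int32
row_idx32 = tf.math.argmax(m, axis=1, output_type=tf.int32)

print(col_idx.numpy(), row_idx.numpy(), row_idx32.dtype)
```

Here col_idx is [1 0 1] (8, 7 and 9 are the column maxima) and row_idx is [1 2] (7 and 9 are the row maxima).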

Example 1:


# importing the library
import tensorflow as tf
# initializing the constant tensor
a = tf.constant([5,10,5.6,7.9,1,50]) # 50 is the maximum value at index 5
# getting the maximum value index tensor
b = tf.math.argmax(input = a)
# printing the tensor
print('tensor: ',b)
# Evaluating the value of tensor
c = tf.keras.backend.eval(b)
#printing the value
print('value: ',c)

Output:

tensor:  tf.Tensor(5, shape=(), dtype=int64)
value:  5
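In TensorFlow 2.x the result in Example 1 is an eager tensor, so its value can also be read with .numpy() instead of tf.keras.backend.eval(). A minimal sketch, assuming eager execution is enabled (the TF 2.x default):

```python
import tensorflow as tf

# Same tensor as Example 1
a = tf.constant([5, 10, 5.6, 7.9, 1, 50])
b = tf.math.argmax(input=a)

# .numpy() reads an eager tensor's value directly as a NumPy scalar
c = b.numpy()
print('value: ', c)

# Passing axis=0 explicitly gives the same index, since 0 is the default axis
print('same with axis=0: ', tf.math.argmax(a, axis=0).numpy() == c)
```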

Example 2: 

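This example uses a tensor of shape (3, 3). Because no axis is passed, argmax() reduces along axis 0 and returns, for each column, the row index of its largest value.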


# importing the library
import tensorflow as tf
# initializing the constant tensor
a = tf.constant(value = [9,8,7,3,5,4,6,2,1],shape = (3,3))
# printing the initialized tensor
print(a)
# getting the maximum value indices tensor
b = tf.math.argmax(input = a)
# printing the tensor
print('Indices Tensor: ',b)
# Evaluating the tensor value
c = tf.keras.backend.eval(b)
# printing the value
print('Indices: ',c)

Output:

tf.Tensor(
[[9 8 7]
 [3 5 4]
 [6 2 1]], shape=(3, 3), dtype=int32)
Indices Tensor:  tf.Tensor([0 0 0], shape=(3,), dtype=int64)
Indices:  [0 0 0]

The maximum values along axis 0 (down each column) are 9, 8 and 7, and all three sit in row 0, so every returned index is 0.
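Example 2 only reduces along axis 0. Passing axis=1 instead finds the column of the maximum in each row. The sketch below, assuming TensorFlow 2.x, reuses the same 3 x 3 matrix and then uses tf.gather with batch_dims=1 to fetch each row's maximum from those indices.

```python
import tensorflow as tf

a = tf.constant(value=[9, 8, 7, 3, 5, 4, 6, 2, 1], shape=(3, 3))

# Column index of the largest element in each row:
# [9 8 7] -> 0, [3 5 4] -> 1, [6 2 1] -> 0
row_argmax = tf.math.argmax(a, axis=1)

# For every row i, pick a[i, row_argmax[i]]
row_max = tf.gather(a, row_argmax, axis=1, batch_dims=1)

print(row_argmax.numpy(), row_max.numpy())
```

The gathered values match what tf.math.reduce_max(a, axis=1) returns, which is a quick way to sanity-check the indices.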
