Python | tensorflow.math.argmax() method
TensorFlow is an open-source Python library designed by Google to develop machine learning models and deep learning neural networks. argmax() is a function in TensorFlow's math module. It returns the index of the maximum value along a given axis of a tensor, not the maximum value itself.
Syntax:
tensorflow.math.argmax(
input, axis, output_type, name
)
Arguments:
1. input: The tensor to search. Allowed dtypes for this tensor are float32,
float64, int32, uint8, int16, int8, complex64, int64, qint8,
quint8, qint32, bfloat16, uint16, complex128, half, uint32, uint64.
2. axis: An integer scalar (not a vector) giving the single axis to reduce along.
Allowed dtypes are int32 and int64, and the value must lie in the range [-rank(input), rank(input)).
Use axis=0 for a vector; this is also the axis used when the argument is omitted.
3. output_type: The dtype of the returned indices.
Allowed values are int32 and int64; the default is int64.
4. name: An optional name for the operation.
Return:
A tensor of dtype output_type containing the indices of the maximum values along the given axis. The reduced axis is removed, so the result has one dimension fewer than input.
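To make the return value concrete, here is a plain-Python sketch of what the reduction computes for a 2-D input. The helper argmax_2d is hypothetical, written only for illustration, and is not part of TensorFlow. It returns the first maximum on ties, which TensorFlow itself does not promise.

```python
def argmax_2d(rows, axis=0):
    """Hypothetical helper: indices of the maxima in a rectangular list of lists."""
    if axis in (1, -1):
        # Reduce across the columns: one index per row.
        return [max(range(len(r)), key=r.__getitem__) for r in rows]
    if axis in (0, -2):
        # Reduce across the rows: one index per column.
        cols = list(zip(*rows))
        return [max(range(len(c)), key=c.__getitem__) for c in cols]
    raise ValueError("axis must be in [-2, 2) for a 2-D input")


matrix = [[9, 8, 7],
          [3, 5, 4],
          [6, 2, 1]]
print(argmax_2d(matrix, axis=0))  # [0, 0, 0]
print(argmax_2d(matrix, axis=1))  # [0, 1, 0]
```

In both calls the result has three entries: the reduced axis disappears and one index is reported for each remaining position.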
Example 1:
Python3
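# Import the library
import tensorflow as tf

# A 1-D tensor; mixing ints and floats yields dtype float32
a = tf.constant([5, 10, 5.6, 7.9, 1, 50])

# Index of the largest element (axis defaults to 0)
b = tf.math.argmax(input=a)
print('tensor: ', b)

# Convert the result to a plain value (eager execution, TensorFlow 2.x)
c = b.numpy()
print('value: ', c)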
Output:
tensor: tf.Tensor(5, shape=(), dtype=int64)
value: 5
Example 2:
This example uses a tensor of shape (3, 3). No axis is passed, so the reduction runs along axis 0 and returns one index per column.
Python3
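# Import the library
import tensorflow as tf

# Reshape nine values into a 3x3 tensor
a = tf.constant(value=[9, 8, 7, 3, 5, 4, 6, 2, 1], shape=(3, 3))
print(a)

# Row index of the maximum in each column (axis defaults to 0)
b = tf.math.argmax(input=a)
print('Indices Tensor: ', b)

# Convert the result to a NumPy array (eager execution, TensorFlow 2.x)
c = b.numpy()
print('Indices: ', c)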
Output:
tf.Tensor(
[[9 8 7]
[3 5 4]
[6 2 1]], shape=(3, 3), dtype=int32)
Indices Tensor: tf.Tensor([0 0 0], shape=(3,), dtype=int64)
Indices: [0 0 0]
# With the default axis 0, the maximum of each column (9, 8 and 7) lies in row 0, so the returned indices are 0, 0, 0.
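Example 2 only exercises the default axis. As a minimal sketch of the other arguments, the snippet below reduces the same matrix along axis 1 and requests int32 indices. It assumes TensorFlow 2.x with eager execution, and the variable names are illustrative.

```python
import tensorflow as tf

a = tf.constant([[9, 8, 7],
                 [3, 5, 4],
                 [6, 2, 1]])

# axis=1 reduces across the columns, giving one index per row.
row_idx = tf.math.argmax(a, axis=1)
print(row_idx.numpy())  # [0 1 0]

# A negative axis counts from the end, so axis=-1 is the same axis here.
# output_type switches the index dtype from the default int64 to int32.
row_idx32 = tf.math.argmax(a, axis=-1, output_type=tf.int32)
print(row_idx32.dtype == tf.int32)  # True
```

The row maxima are 9, 5 and 6, found at column positions 0, 1 and 0.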
Last Updated: 01 Oct, 2021