
TensorFlow.js tf.cast() Function

TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in a Node.js environment.

The tf.cast() function converts a given tensor to a specified data type and returns the converted tensor.



Syntax:

tf.cast(x, dtype)

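Parameters: This function accepts two parameters:

x: The input tensor to be cast. It can be a tf.Tensor, a TypedArray, or an Array.
dtype: The data type to cast the input tensor to, given as a string such as 'float32', 'int32', or 'bool'.
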
Return Value: It returns a tf.Tensor whose values have been cast to the specified data type.

Example 1:




// Importing the TensorFlow.js library
import * as tf from "@tensorflow/tfjs";

// Initializing a 1-D tensor of floating-point values
const x = tf.tensor1d([2.3, 1.7, 5, 0, 1, 0.5]);

// Casting the tensor to the 'int32' data type
// and printing the result
tf.cast(x, 'int32').print();

Output:

Tensor
   [2, 1, 5, 0, 1, 0]
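
Note that 1.7 becomes 1 and 0.5 becomes 0 in the output above, so the fractional part is dropped rather than rounded. The following plain-JavaScript sketch reproduces that behavior without TensorFlow.js, using an Int32Array as a stand-in for int32 storage. The variable names are hypothetical, and the sketch only illustrates the observed result; it is not the library's own implementation.

```javascript
// Same input values as Example 1.
const floats = [2.3, 1.7, 5, 0, 1, 0.5];

// Storing a number in an Int32Array drops its fractional part
// (truncation toward zero), which matches the output above.
const asInt32 = Array.from(Int32Array.from(floats));

// Math.round is shown only for contrast: it gives different
// results for 1.7 and 0.5.
const rounded = floats.map((v) => Math.round(v));

console.log(asInt32); // [ 2, 1, 5, 0, 1, 0 ]
console.log(rounded); // [ 2, 2, 5, 0, 1, 1 ]
```

If rounded values are needed, one option is to round the tensor first and then cast it.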

Example 2:




// Importing the TensorFlow.js library
import * as tf from "@tensorflow/tfjs";

// Passing a tensor directly to tf.cast()
// to cast its values to the 'bool' data type
tf.cast(tf.tensor1d([0, 1, -3]), 'bool').print();

Output:

Tensor
   [false, true, true]
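
In the output above, 0 maps to false, while both 1 and -3 map to true. The rule this suggests is "any nonzero value is true". A minimal plain-JavaScript sketch of that rule follows. The names are hypothetical, and it mirrors only the result shown here, not the library's internals.

```javascript
// Same input values as Example 2.
const inputs = [0, 1, -3];

// Assumed rule, consistent with the output above:
// a value is true exactly when it is not equal to zero.
const asBool = inputs.map((v) => v !== 0);

console.log(asBool); // [ false, true, true ]
```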

Reference: https://js.tensorflow.org/api/latest/#cast
