
TensorFlow.js tf.GraphModel class .predict() Method

TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in a Node.js environment.

The .predict() function is used to run inference on the given input tensors.


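Syntax:

predict(inputs, config?)

Parameters:

inputs: The input tensor or tensors for the model, given as a tf.Tensor, a tf.Tensor[], or an object that maps input names to tensors.
config (optional): The prediction configuration, which may specify batchSize (a number) and verbose (a boolean).

Return Value: It returns tf.Tensor|tf.Tensor[]|{[name: string]: tf.Tensor}.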
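Because the return value can take three different shapes, calling code sometimes needs to normalize it before use. The following is a minimal sketch of one way to do that. The helper name toTensorList is hypothetical and is not part of TensorFlow.js, and it assumes that a tensor can be recognized by the presence of a dataSync() method; in real code, an instanceof tf.Tensor check may be preferable.

```javascript
// Hypothetical helper (not part of TensorFlow.js): converts any of the
// three return shapes of predict() into a plain array of tensors.
function toTensorList(output) {
  // Case 1: predict() returned an array of tensors.
  if (Array.isArray(output)) {
    return output;
  }
  // Case 2: predict() returned a single tensor.
  // Assumption: a tensor is detected by duck typing on dataSync().
  if (output !== null && typeof output === "object" &&
      typeof output.dataSync === "function") {
    return [output];
  }
  // Case 3: predict() returned an object that maps names to tensors.
  return Object.values(output);
}
```

With a helper like this, the result of model.predict(inputs) can be processed with a single loop regardless of how many outputs the model defines.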

Example 1: In this example, we load MobileNetV2 from a URL and run a prediction on an all-zeros input.

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";
// Defining the model URL
const model_Url =
// Calling the loadGraphModel() method
const mymodel = await tf.loadGraphModel(model_Url);
// Defining inputs
const inputs = tf.zeros([1, 224, 224, 3]);
// Calling predict() method and
// Printing output
mymodel.predict(inputs).print();

Output:

     [[-0.1800361, -0.4059965, 0.8190175,
     -0.8953396, -1.0841646, 1.2912753],]

Example 2: In this example, we load MobileNetV2 from a TF Hub URL and run a prediction on an all-zeros input, passing a configuration object.

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";
// Defining the model URL
const model_Url =
// Calling the loadGraphModel() method
const model = await tf.loadGraphModel(
        model_Url, {fromTFHub: true});
// Defining inputs
const inputs = tf.zeros([1, 224, 224, 3]);
// Defining batch size
const batchsize = 1;
// Defining verbose
const verbose = true;
// Calling predict() method with a config object
// and printing output
model.predict(inputs, {batchSize: batchsize, verbose: verbose}).print();


Output:

     [[-1.1690605, 0.0195426, 1.1962479,
     -0.4825858, -0.0055641, 1.1937635],]
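The printed values are raw scores, so a common next step is to find which positions score highest. The following is a rough sketch of that step on plain numbers. The helper name topK is hypothetical, and it assumes the scores have already been copied out of the prediction tensor into an array or typed array, for example with await output.data(). The six values shown above are reused purely as sample data; how an index maps to a class label depends on the model.

```javascript
// Hypothetical helper (not part of TensorFlow.js): returns the k
// highest-scoring entries as {index, score} pairs, best first.
// `scores` may be a plain array or a typed array such as Float32Array.
function topK(scores, k) {
  return Array.from(scores)
    .map((score, index) => ({ index, score }))
    .sort((a, b) => b.score - a.score)
    .slice(0, k);
}

// Sample data only: the six values printed in the output above.
const sample = [-1.1690605, 0.0195426, 1.1962479,
                -0.4825858, -0.0055641, 1.1937635];

// Logs the entries at index 2 and index 5, in that order.
console.log(topK(sample, 2));
```

Sorting a copy keeps the original scores untouched, which is convenient when the same values are needed again later.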

