Tensorflow.js tf.loadGraphModel() Function


TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in Node.js.

The tf.loadGraphModel() function loads a graph model (for example, a TensorFlow SavedModel converted with the TensorFlow.js converter) given a URL to the model definition.

Syntax:

tf.loadGraphModel(modelUrl, options)

Parameters:

  • modelUrl: A string URL of the model definition (model.json) or an io.IOHandler that tells TensorFlow.js where and how to load the model.
  • options: An optional object that configures loading. For HTTP loading it can carry credentials and custom headers. It may contain the following fields:
    • requestInit: A RequestInit object used for the HTTP request (for example, custom headers or credentials).
    • onProgress: A progress callback, called with the fraction of loading completed (between 0 and 1).
    • fetchFunc: A function used to override the window.fetch function.
    • strict: A boolean for strict loading mode: whether extraneous weights or missing weights should trigger an Error.
    • weightPathPrefix: The path prefix for weight files. By default it is calculated from the path of the model JSON file.
    • fromTFHub: A boolean indicating whether the module or model is to be loaded from TF Hub.
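Because options is a plain JavaScript object, it can be assembled before the call. The sketch below only builds such an object; `makeLoadOptions` is a hypothetical helper, and the bearer token and the weight-file prefix are placeholder values assumed for illustration, not part of the original article.

```javascript
// Hypothetical helper that assembles an options object for
// tf.loadGraphModel(). The field names come from the options listed above;
// the token and weight prefix passed in are placeholders.
function makeLoadOptions(token, weightPrefix, progressLog) {
  return {
    // RequestInit fields used for the HTTP request
    requestInit: {
      headers: { Authorization: `Bearer ${token}` },
      credentials: "include",
    },
    // Records each reported fraction (0 to 1) of loading progress
    onProgress: (fraction) => progressLog.push(fraction),
    // Load weight files from this prefix instead of the folder containing
    // model.json (keep the trailing slash, as the default prefix has one)
    weightPathPrefix: weightPrefix,
  };
}

const progress = [];
const options = makeLoadOptions(
  "example-token",
  "https://cdn.example.com/mobilenet/",
  progress
);

// With TensorFlow.js imported as tf, the object is passed as the
// second argument:
// const model = await tf.loadGraphModel(modelUrl, options);
```

Keeping the options in one helper makes it easy to reuse the same headers and progress reporting for several models.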

Return value: It returns a Promise<tf.GraphModel>, which resolves once the model has been loaded.
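Since the result is a Promise, a loading failure (for example, an unreachable URL) shows up as a rejected Promise that can be caught with try/catch. A minimal sketch: `loadOrNull` is a hypothetical helper, and the loader is passed in as a parameter (intended to be tf.loadGraphModel) so the sketch can run with stand-in loaders instead of TensorFlow.js.

```javascript
// Hypothetical helper: awaits a model loader and returns null on failure.
// `loader` is intended to be tf.loadGraphModel; it is a parameter so that
// the function can be tried out with a stand-in loader.
async function loadOrNull(loader, modelUrl, options) {
  try {
    // Resolves to the loaded model (a tf.GraphModel when loader is
    // tf.loadGraphModel)
    return await loader(modelUrl, options);
  } catch (err) {
    // A rejected Promise (e.g. a failed request) ends up here
    console.error(`Could not load ${modelUrl}: ${err.message}`);
    return null;
  }
}

// Stand-in loaders used only to demonstrate both outcomes
const fakeModel = { kind: "fake graph model" };
const okLoader = async () => fakeModel;
const failingLoader = async (url) => {
  throw new Error(`HTTP 404 for ${url}`);
};
```

With TensorFlow.js, the call would look like `const model = await loadOrNull(tf.loadGraphModel, modelUrl);`, followed by a null check before calling model.predict().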

Example 1: In this example, we load MobileNetV2 from a URL and make a prediction on an all-zeros input.

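Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

// URL of the MobileNetV2 model.json file; replace "..." with the
// actual model URL
const modelUrl = "...";

// Calling the loadGraphModel() method (top-level await requires
// an ES module context)
const model = await tf.loadGraphModel(modelUrl);

// Creating an all-zeros input of shape [1, 224, 224, 3] and
// printing the model's prediction for it
const zeros = tf.zeros([1, 224, 224, 3]);
model.predict(zeros).print();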

Output:

Tensor
     [[-0.1412081, -0.5656458, 0.7578365, ..., 
     -1.0148169, -0.81284, 1.1898142],]
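The printed tensor holds one row of class scores for the single input image. To turn such scores into class indices, the values can be read out of the tensor (for example with `await prediction.data()`) and ranked. A minimal plain-JavaScript sketch, with `topK` as a hypothetical helper; TensorFlow.js also provides `tf.topk` for doing this on tensors directly.

```javascript
// Hypothetical helper: returns the k highest scores with their indices.
// `scores` can be any array-like of numbers, such as the Float32Array
// returned by `await prediction.data()` in TensorFlow.js.
function topK(scores, k) {
  return Array.from(scores, (score, index) => ({ index, score }))
    // Sort the entries in descending order of score
    .sort((a, b) => b.score - a.score)
    // Keep only the first k entries
    .slice(0, k);
}

// Example with made-up scores (not real MobileNetV2 output)
const best = topK(new Float32Array([0.1, 0.9, 0.3, 0.7]), 2);
console.log(best);
```

With a loaded model, the scores would come from `await model.predict(zeros).data()`.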

Example 2: In this example, we load MobileNetV2 from a TF Hub URL and make a prediction on an all-zeros input.

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

// TF Hub URL of the MobileNetV2 module; replace "..." with the
// actual module URL
const modelUrl = "...";

// Calling the loadGraphModel() method with fromTFHub set to true
const model = await tf.loadGraphModel(
        modelUrl, { fromTFHub: true });

// Creating an all-zeros input of shape [1, 224, 224, 3] and
// printing the model's prediction for it
const zeros = tf.zeros([1, 224, 224, 3]);
model.predict(zeros).print();

Output:

Tensor
     [[-1.0764486, 0.0097444, 1.1630495, ..., 
     -0.345558, 0.035432, 0.9112286],]
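With fromTFHub: true, modelUrl is a TF Hub module URL rather than a direct link to model.json. As far as we can tell from the tfjs-converter source, the library appends model.json and a tfjs-format=file query parameter to such a URL before fetching it. The sketch below reproduces that rewriting as an illustration only: `tfHubModelJsonUrl` is hypothetical, the module URL is made up, and the exact behavior may differ between TensorFlow.js versions.

```javascript
// Hypothetical illustration of the fromTFHub URL rewriting described above.
// This is an assumption based on the tfjs-converter source, not a public API.
function tfHubModelJsonUrl(moduleUrl) {
  // Make sure the module URL ends with a slash before appending a file name
  const base = moduleUrl.endsWith("/") ? moduleUrl : `${moduleUrl}/`;
  // Ask TF Hub for the TensorFlow.js files of the module
  return `${base}model.json?tfjs-format=file`;
}

// Hypothetical module URL, used only for illustration
console.log(tfHubModelJsonUrl("https://tfhub.dev/some-publisher/some-model/1"));
```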

Reference: https://js.tensorflow.org/api/1.0.0/#loadGraphModel


Last Updated : 31 May, 2021