TensorFlow.js tf.loadGraphModel() Function

Last Updated: 31 May, 2021

TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in Node.js.

The .loadGraphModel() function is used to load a graph model, given a URL to the model definition.


Syntax:

tf.loadGraphModel(modelUrl, options)

Parameters:



  • modelUrl: The first parameter, of type string or io.IOHandler. It is either the URL of the model definition (the model JSON file) or an io.IOHandler that loads the model.
  • options: The second parameter, which is optional. It configures how the model is loaded. This includes the HTTP request, which allows sending credentials and custom headers. The available options are:
    • requestInit: A RequestInit object used for the HTTP requests.
    • onProgress: A progress callback.
    • fetchFunc: A function used to override the window.fetch function.
    • strict: A boolean for strict loading. It controls whether extraneous weights or missing weights should trigger an Error.
    • weightPathPrefix: The path prefix for the weight files. By default, it is calculated from the path of the model JSON file.
    • fromTFHub: A boolean indicating whether the module or model is to be loaded from TF Hub.
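As a sketch of how these options fit together, the snippet below builds an options object. It is a minimal example under assumptions:

- `makeLoadOptions` and `defaultWeightPathPrefix` are hypothetical helpers, not part of TensorFlow.js.
- The bearer token, the header name and the progress message format are placeholders.
- `defaultWeightPathPrefix` only approximates the default prefix: the model URL up to and including its last `/`, ignoring any query string.

```javascript
// Hypothetical helper: approximates the default weightPathPrefix by taking
// everything in the model URL up to and including the last "/".
// Query strings are ignored in this sketch.
function defaultWeightPathPrefix(modelUrl) {
  const lastSlash = modelUrl.lastIndexOf("/");
  return modelUrl.substring(0, lastSlash) + "/";
}

// Hypothetical helper: builds an options object for tf.loadGraphModel().
// `token` and the log message format are placeholders; every option is optional.
function makeLoadOptions({ token, weightPathPrefix, log = console.log } = {}) {
  const options = {
    // Settings passed along with the HTTP requests (credentials, custom headers).
    requestInit: {
      credentials: "include",
      headers: token ? { Authorization: `Bearer ${token}` } : {},
    },
    // Progress callback; assumed to receive the fraction loaded (0 to 1).
    onProgress: (fraction) =>
      log(`Loading model: ${Math.round(fraction * 100)}%`),
    // Ask for an Error on extraneous or missing weights.
    strict: true,
  };
  // Only override the prefix when one is given, so the default still applies otherwise.
  if (weightPathPrefix !== undefined) {
    options.weightPathPrefix = weightPathPrefix;
  }
  return options;
}
```

With TensorFlow.js imported as `tf`, the result would be passed as the second argument, for example `tf.loadGraphModel(modelUrl, makeLoadOptions({ token }))`. Leaving `weightPathPrefix` unset keeps the library's default behaviour.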

Return value: It returns a Promise<tf.GraphModel>, which resolves to the loaded model.
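Because the function returns a Promise, it can also be consumed with `.then()` and `.catch()` instead of `await`. In the sketch below, `loadModelOrReport` is a hypothetical helper. It takes the loader as a parameter, assumed to have the same `(modelUrl, options) => Promise<model>` shape as `tf.loadGraphModel`, so the promise handling can be tried without a network.

```javascript
// Hypothetical helper: runs a loader with the same shape as tf.loadGraphModel
// and turns success or failure into a plain result object.
function loadModelOrReport(loader, modelUrl, options = {}) {
  // Starting from Promise.resolve() also turns a synchronous throw into a rejection.
  return Promise.resolve()
    .then(() => loader(modelUrl, options))
    .then((model) => ({ ok: true, model }))
    .catch((error) => ({ ok: false, error }));
}
```

With TensorFlow.js imported as `tf`, a call might look like `loadModelOrReport(tf.loadGraphModel, modelUrl).then((result) => { if (!result.ok) console.error(result.error); })`.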

Example 1: In this example, we are loading MobileNetV2 from a URL and making a prediction with a zeros input.

Javascript




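// Importing the TensorFlow.js library
import * as tf from "@tensorflow/tfjs";

// URL of the MobileNetV2 model.json file
// (placeholder: replace with the actual model URL)
const modelUrl = "MODEL_JSON_URL";

// Calling the loadGraphModel() method
// (top-level await requires running this as an ES module)
const model = await tf.loadGraphModel(modelUrl);

// Creating an all-zeros input of shape [1, 224, 224, 3] and printing the prediction
const zeros = tf.zeros([1, 224, 224, 3]);
model.predict(zeros).print();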

Output:

Tensor
     [[-0.1412081, -0.5656458, 0.7578365, ..., 
     -1.0148169, -0.81284, 1.1898142],]
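The printed tensor holds one score per class for the single input image. The sketch below picks the index of the highest score.

- It assumes the scores are first read into a flat array with `await prediction.data()`, which works directly for a batch of one.
- `argMax` here is a hypothetical helper. TensorFlow.js also provides `tf.argMax` for doing this on the tensor itself.
- Turning the index into a label would need a class list, which this article does not include.

```javascript
// Hypothetical helper: returns the index of the largest score
// (the first such index on ties) in an array or typed array.
function argMax(scores) {
  if (scores.length === 0) throw new RangeError("argMax of empty input");
  let best = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}
```

For example, `argMax(await model.predict(zeros).data())` would give the top-scoring class index for the zeros input.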

Example 2: In this example, we are loading MobileNetV2 from a TF Hub URL and making a prediction with a zeros input.

Javascript




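// Importing the TensorFlow.js library
import * as tf from "@tensorflow/tfjs";

// TF Hub URL of the MobileNetV2 module
// (placeholder: replace with the actual TF Hub URL)
const modelUrl = "TF_HUB_MODULE_URL";

// Calling the loadGraphModel() method with fromTFHub enabled
// (top-level await requires running this as an ES module)
const model = await tf.loadGraphModel(
        modelUrl, {fromTFHub: true});

// Creating an all-zeros input of shape [1, 224, 224, 3] and printing the prediction
const zeros = tf.zeros([1, 224, 224, 3]);
model.predict(zeros).print();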

Output:

Tensor
     [[-1.0764486, 0.0097444, 1.1630495, ..., 
     -0.345558, 0.035432, 0.9112286],]
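With `fromTFHub: true`, the module URL is rewritten before the model JSON is fetched. The helper below is a sketch of that rewrite as it appears in the tfjs-converter source. If the URL has no trailing `/`, one is added, then `model.json?tfjs-format=file` is appended. Treat it as an approximation of library internals rather than a public API. `tfHubModelJsonUrl` is a hypothetical name, and the URLs in the test are made up.

```javascript
// Hypothetical helper approximating the fromTFHub URL rewrite:
// add a trailing "/" if missing, then append "model.json?tfjs-format=file".
function tfHubModelJsonUrl(moduleUrl) {
  const base = moduleUrl.endsWith("/") ? moduleUrl : moduleUrl + "/";
  return `${base}model.json?tfjs-format=file`;
}
```

This also explains why the flag matters: without it, the TF Hub module URL would be requested as if it were the model JSON file itself.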

Reference: https://js.tensorflow.org/api/1.0.0/#loadGraphModel
