
TensorFlow.js tf.LayersModel class .trainOnBatch() Method

Last Updated: 01 Sep, 2021

TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in Node.js.

The .trainOnBatch() method runs a single gradient update on a single batch of data.
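To make "a single gradient update" concrete, here is a conceptual sketch in plain JavaScript with no TensorFlow.js calls. It assumes a single dense layer, mean squared error and a plain gradient-descent (SGD) update, roughly the setup used in the examples below. The function name `denseSgdStep` and its array-based interface are hypothetical; TensorFlow.js computes gradients with automatic differentiation rather than by hand like this.

```javascript
// Conceptual sketch only (NOT TensorFlow.js internals): one plain
// gradient-descent update of a dense layer y = x*W + b under mean squared
// error, computed on a single batch held in plain JavaScript arrays.
function denseSgdStep(W, b, xs, ys, learningRate) {
  const batch = xs.length;
  const units = b.length;
  const count = batch * units; // number of squared errors being averaged

  // Forward pass: preds[i][j] = b[j] + sum over k of xs[i][k] * W[k][j]
  const preds = xs.map((row) =>
    b.map((bj, j) => row.reduce((acc, xk, k) => acc + xk * W[k][j], bj))
  );

  // Loss for this batch, measured with the weights before the update.
  let loss = 0;
  const dPreds = preds.map((row, i) =>
    row.map((p, j) => {
      const err = p - ys[i][j];
      loss += err * err;
      return (2 * err) / count; // derivative of the mean of squared errors
    })
  );
  loss /= count;

  // Gradient step: dW = transpose(xs) * dPreds, db = column sums of dPreds.
  const newW = W.map((wRow, k) =>
    wRow.map((w, j) => {
      let grad = 0;
      for (let i = 0; i < batch; i++) grad += xs[i][k] * dPreds[i][j];
      return w - learningRate * grad;
    })
  );
  const newB = b.map((bj, j) => {
    let grad = 0;
    for (let i = 0; i < batch; i++) grad += dPreds[i][j];
    return bj - learningRate * grad;
  });

  return { loss, W: newW, b: newB };
}
```

Starting from zero weights with all-ones inputs and targets, one call reports a loss of 1 and nudges every weight and bias toward the targets; calling it again on the same batch reports a smaller loss. Each call is one update, which is the unit of work trainOnBatch() performs.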


Note: This method differs from fit() and fitDataset() in the following ways:

  • It operates on exactly one batch of data.
  • It returns only the loss and metric values, instead of returning the batch-by-batch loss and metric values.
  • It doesn't support fine-grained options such as verbosity and callbacks.
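Because trainOnBatch() performs exactly one update and has no epochs, batching or callbacks of its own, a caller who wants to train on a larger dataset has to write the loop that fit() would otherwise handle. The sketch below assumes the data is held in plain arrays; `trainInBatches` is a hypothetical helper, and `trainStep` stands in for something like `(bx, by) => model.trainOnBatch(bx, by)`. With real tensors you would slice the batches with the tensor API instead of `Array.prototype.slice`.

```javascript
// Hypothetical hand-written training loop: the caller slices the data into
// mini-batches and performs exactly one update per batch via trainStep.
async function trainInBatches(xs, ys, batchSize, trainStep) {
  if (xs.length !== ys.length) {
    throw new Error("xs and ys must contain the same number of samples");
  }
  const losses = [];
  for (let start = 0; start < xs.length; start += batchSize) {
    const bx = xs.slice(start, start + batchSize);
    const by = ys.slice(start, start + batchSize);
    // One call = one update; `await` accepts both promises and plain values.
    losses.push(await trainStep(bx, by));
  }
  return losses; // one loss per batch; any logging is up to the caller
}
```

Anything fit() would report through verbosity or callbacks, such as per-batch logging or early stopping, has to be done inside this loop by the caller.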

Syntax:

trainOnBatch(x, y)

Parameters:

  • x: The input data. It can be of type tf.Tensor, tf.Tensor[], or {[inputName: string]: tf.Tensor}, i.e. one of the following:
    1. A single tf.Tensor, or an array of tf.Tensors if the model has multiple inputs.
    2. An object mapping input names to the corresponding tf.Tensor if the model has named inputs.
  • y: The target data. It can be of type tf.Tensor, tf.Tensor[], or {[inputName: string]: tf.Tensor}. It must be consistent with x.
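Requiring y to be consistent with x means, at minimum, that inputs and targets describe the same samples, so their first (batch) dimension must match. The hypothetical helper below checks this on plain shape arrays standing in for `tensor.shape`, and accepts each argument in any of the three forms listed above: a single shape, an array of shapes, or an object keyed by name. It illustrates the rule; it is not the check TensorFlow.js performs internally.

```javascript
// Hypothetical helpers: plain shape arrays (e.g. [3, 2]) stand in for the
// .shape of real tensors.
function toShapeList(data) {
  if (Array.isArray(data) && data.every((d) => typeof d === "number")) {
    return [data]; // a single shape such as [3, 2]
  }
  if (Array.isArray(data)) return data; // array of shapes (multiple inputs/outputs)
  return Object.values(data); // object mapping names to shapes
}

// True when every input and target has the same number of samples (dim 0).
function sameNumberOfSamples(x, y) {
  const sizes = [...toShapeList(x), ...toShapeList(y)].map((shape) => shape[0]);
  return sizes.every((n) => n === sizes[0]);
}
```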

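Return Value: It returns a Promise that resolves to a number or an array of numbers (Promise<number|number[]>): the training loss, or, if the model has multiple outputs or was compiled with metrics, an array holding the loss values together with the metric values.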
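Since the resolved value may be either a plain number or an array, it can help to normalize it before use. The helper `unpackResult` below is hypothetical, and it assumes the array lists the loss first, followed by the compiled metrics in the order they were passed to compile(); check that ordering against your own model before relying on it.

```javascript
// Hypothetical helper: turn the number | number[] result of trainOnBatch()
// into an object. Assumes [loss, metric1, metric2, ...] ordering.
function unpackResult(result, metricNames = []) {
  const values = Array.isArray(result) ? result : [result];
  const out = { loss: values[0] };
  metricNames.forEach((name, i) => {
    out[name] = values[i + 1];
  });
  return out;
}
```

For a model compiled with `metrics: ['mse']`, usage would look like `const { loss, mse } = unpackResult(await model.trainOnBatch(xs, ys), ['mse']);`.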

Example 1:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

// Defining a model with a single dense layer
const mymodel = tf.sequential({
  layers: [tf.layers.dense({units: 2, inputShape: [2]})]
});

// Compiling the model
const config = {optimizer: 'sgd', loss: 'meanSquaredError'};
mymodel.compile(config);

// Input tensor and target tensor (3 samples, 2 values each)
const xs = tf.ones([3, 2]);
const ys = tf.ones([3, 2]);

// Calling the trainOnBatch() method
// (top-level await requires running this file as an ES module)
const result = await mymodel.trainOnBatch(xs, ys);

// Printing output
console.log(result);

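Output:

2.0696773529052734

The exact value differs from run to run because the dense layer's weights are initialized randomly.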

Example 2:

Javascript

// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

async function run() {
  // Defining a model with a single dense layer and sigmoid activation
  const mymodel = tf.sequential({
    layers: [tf.layers.dense({units: 2, inputShape: [2],
                              activation: 'sigmoid'})]
  });

  // Compiling the model
  const config = {optimizer: 'sgd', loss: 'meanSquaredError'};
  mymodel.compile(config);

  // Random input tensor and random target tensor
  const xs = tf.truncatedNormal([3, 2]);
  const ys = tf.randomNormal([3, 2]);

  // Calling the trainOnBatch() method
  const result = await mymodel.trainOnBatch(xs, ys);

  // Printing output
  console.log(result);
}

// Function call
run();

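Output:

0.5935208797454834

The exact value differs from run to run because the inputs, the targets and the initial weights are all random.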

Reference: https://js.tensorflow.org/api/latest/#tf.LayersModel.trainOnBatch



