
TensorFlow.js tf.Sequential class .trainOnBatch() Method

TensorFlow.js is an open-source library developed by Google for training and running machine learning models, including deep neural networks, in the browser or in Node.js.

The .trainOnBatch() method runs a single gradient update on a single batch of data.
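To make "a single gradient update on a single batch" concrete, here is a minimal conceptual sketch in plain JavaScript. It is not how TensorFlow.js is implemented; the model (y = w * x + b), the mean squared error loss, the helper name sgdStepOnBatch and the learning rate of 0.25 are all assumptions chosen for illustration.

```javascript
// Conceptual sketch only (hypothetical helper, not a TensorFlow.js API):
// one SGD step for the model y = w * x + b with mean squared error.
function sgdStepOnBatch(params, xs, ys, learningRate) {
  const n = xs.length;
  let loss = 0;
  let gradW = 0;
  let gradB = 0;
  for (let i = 0; i < n; i++) {
    const error = params.w * xs[i] + params.b - ys[i];
    loss += (error * error) / n;      // mean squared error over the batch
    gradW += (2 * error * xs[i]) / n; // d(loss)/dw
    gradB += (2 * error) / n;         // d(loss)/db
  }
  // Exactly one parameter update; the loss is the one measured before it
  return {
    loss,
    params: {
      w: params.w - learningRate * gradW,
      b: params.b - learningRate * gradB,
    },
  };
}

const step = sgdStepOnBatch({w: 0, b: 0}, [1, 2], [2, 4], 0.25);
console.log(step.loss);   // 10
console.log(step.params); // { w: 2.5, b: 1.5 }
```

Calling the helper again with the returned parameters would perform the next update, which mirrors how repeated trainOnBatch() calls each advance the model by one step.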



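Note:

This method differs from fit() and fitDataset() in the following ways:

- It operates on exactly one batch of data.
- It returns only the loss and metric values, rather than the batch-by-batch loss and metric values.
- It does not support fine-grained options such as verbosity and callbacks.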



Syntax:

trainOnBatch(x, y)

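Parameters:

- x: The input data. It can be a tf.Tensor, an array of tf.Tensors (if the model has multiple inputs), or an object that maps input names to tf.Tensors (if the model has named inputs).
- y: The target data. It accepts the same types as x.

Return Value: It returns a Promise that resolves to a number (the training loss) or to a number[] (the loss or losses followed by the metric values, when the model has multiple outputs or was compiled with metrics).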

Example 1:




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

// Creating a sequential model
const gfg = tf.sequential();

// Adding a dense layer to the model
const layer = tf.layers.dense({units: 3, inputShape: [5]});
gfg.add(layer);

// Compiling the model
const config = {optimizer: 'sgd', loss: 'meanSquaredError'};
gfg.compile(config);

// Input tensor and target tensor
const xs = tf.ones([3, 5]);
const ys = tf.ones([3, 3]);

// Calling trainOnBatch() method
// (top-level await requires an ES module context)
const result = await gfg.trainOnBatch(xs, ys);

// Printing output
console.log(result);

Output (the exact value varies between runs because the layer weights are initialized randomly):

0.3589147925376892

Example 2:




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

async function run() {

  // Creating a sequential model
  const gfg = tf.sequential();

  // Adding a dense layer to the model
  const layer = tf.layers.dense({units: 2, inputShape: [2]});
  gfg.add(layer);

  // Compiling the model
  const config = {optimizer: 'sgd', loss: 'meanSquaredError'};
  gfg.compile(config);

  // Random input tensor and random target tensor
  const xs = tf.truncatedNormal([3, 2]);
  const ys = tf.randomNormal([3, 2]);

  // Calling trainOnBatch() method
  const result = await gfg.trainOnBatch(xs, ys);

  // Printing output (a single number, since no metrics are configured)
  console.log(result);
}

// Function call
await run();

Output (the exact value varies between runs because both the data and the layer weights are random):

1.6889342069625854

Reference: https://js.tensorflow.org/api/latest/#tf.Sequential.trainOnBatch

