
TensorFlow.js tf.LayersModel class .fit() Method

TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in Node.js.

The .fit() method of the tf.LayersModel class trains the model for a fixed number of epochs (iterations over a dataset).



Syntax:

fit(x, y, args?)

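Parameters: This method accepts the following parameters.

- x: The input data: a tf.Tensor, an array of tf.Tensors (for models with several inputs), or an object mapping input names to tf.Tensors.
- y: The target data, in the same forms as x.
- args: An optional object of training options, including:
  - batchSize: The number of samples per gradient update. Defaults to 32.
  - epochs: The number of times to iterate over the training data.
  - verbose: The logging verbosity level.
  - callbacks: Functions to be called at various stages of training.
  - validationSplit: A number between 0 and 1. That fraction of the data, taken from the last samples before shuffling, is held out for validation.
  - validationData: Data on which to evaluate the loss and metrics at the end of each epoch.
  - shuffle: Whether to shuffle the training data before each epoch.
  - initialEpoch: The epoch at which to start training. Training then runs until the epoch of index epochs is reached.
  - stepsPerEpoch: The total number of steps (batches of samples) before one epoch is declared finished.
  - validationSteps: Only relevant if stepsPerEpoch is set; the total number of validation steps before stopping.
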
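Returns: It returns a Promise that resolves to a History object. Its history property maps each loss or metric name (such as loss and, when validation is used, val_loss) to an array with one value per epoch.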

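Example 1: The model is trained four times in a row on the same data. Each fit() call continues from the weights left by the previous call, so the printed loss (from the second epoch of each call) keeps decreasing.
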
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

// Defining the model: one dense layer with 2 units and 6 inputs
const mymodel = tf.sequential({
  layers: [tf.layers.dense({units: 2, inputShape: [6]})]
});

// Compiling the above model
mymodel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});

// Call fit() four times; each call continues training from the
// current weights (top-level await requires an ES module context)
for (let i = 0; i < 4; i++) {
  const his = await mymodel.fit(tf.zeros([6, 6]), tf.ones([6, 2]), {
    batchSize: 5,
    epochs: 4
  });

  // Print the loss of the second epoch (index 1) of this call
  console.log(his.history.loss[1]);
}

Output:

0.9574100375175476
0.8151942491531372
0.694103479385376
0.5909997820854187
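Because every input in Example 1 is zero, the dense layer's kernel receives no gradient and only its bias is trained, which makes the printed losses easy to reproduce by hand. The plain-JavaScript sketch below does not use TensorFlow.js at all; the function simulateExample1 is hypothetical. It assumes a zero-initialized bias (the dense layer default), the 'sgd' optimizer's default learning rate of 0.01, and that history.loss stores, for each epoch, the mean of the per-batch losses weighted by batch size. With batchSize: 5 and 6 samples, each epoch runs two steps: a batch of 5 and a batch of 1.

```javascript
// Hypothetical plain-JavaScript sketch (no TensorFlow.js) of the training
// in Example 1. Assumptions: inputs are all zero, so only the bias is
// trained; the bias starts at 0; plain SGD with learning rate 0.01; each
// history.loss entry is the batch-size-weighted mean of the batch losses.
function simulateExample1({ numSamples = 6, units = 2, batchSize = 5,
                            epochs = 4, calls = 4, learningRate = 0.01 } = {}) {
  const bias = new Array(units).fill(0); // dense-layer bias, zero-initialized
  const printed = [];
  for (let call = 0; call < calls; call++) { // the outer for loop
    const epochLosses = [];
    for (let epoch = 0; epoch < epochs; epoch++) {
      let weightedSum = 0;
      // Walk through the samples in mini-batches of at most batchSize rows.
      for (let start = 0; start < numSamples; start += batchSize) {
        const size = Math.min(batchSize, numSamples - start);
        // The prediction equals the bias (zero input) and every target is 1,
        // so the batch MSE is the mean over units of (1 - b)^2.
        const loss = bias.reduce((s, b) => s + (1 - b) ** 2, 0) / units;
        weightedSum += loss * size;
        // d(loss)/d(b_j) = -2 * (1 - b_j) / units; take one SGD step.
        for (let j = 0; j < units; j++) {
          bias[j] -= learningRate * (-2 * (1 - bias[j]) / units);
        }
      }
      epochLosses.push(weightedSum / numSamples);
    }
    printed.push(epochLosses[1]); // corresponds to his.history.loss[1]
  }
  return printed;
}

console.log(simulateExample1());
```

Under these assumptions the sketch prints values that agree with the Output above to about six decimal places; the small remaining differences come from TensorFlow.js computing in 32-bit floats.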

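Example 2: This model uses a sigmoid activation and is trained on random data with several additional options: validationSplit, shuffle, initialEpoch, stepsPerEpoch and validationSteps.
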
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

// Defining the model with a sigmoid activation
const mymodel = tf.sequential({
  layers: [tf.layers.dense({units: 2, inputShape: [6],
                            activation: "sigmoid"})]
});

// Compiling the above model
mymodel.compile({optimizer: 'sgd', loss: 'meanSquaredError'});

// Calling fit() on random data; with initialEpoch: 2 and epochs: 4,
// only the epochs with index 2 and 3 are run
const his = await mymodel.fit(tf.truncatedNormal([6, 6]),
                              tf.randomNormal([6, 2]), {
  batchSize: 5,
  epochs: 4,
  validationSplit: 0.2,
  shuffle: true,
  initialEpoch: 2,
  stepsPerEpoch: 1,
  validationSteps: 2
});

// Printing the whole training history
console.log(JSON.stringify(his.history));

Output (the exact values vary from run to run because the data and the initial weights are random):

{"val_loss":[0.35800713300704956,0.35819053649902344],
"loss":[0.633269190788269,0.632409930229187]}

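The output holds two values per key even though epochs is 4. The hypothetical helper describeFitPlan below, plain JavaScript rather than any TensorFlow.js API, works through the bookkeeping implied by Example 2's options. It assumes the split point is computed as Math.floor(numSamples * (1 - validationSplit)), with the validation rows taken from the end of the data as the documentation describes, and it deliberately ignores stepsPerEpoch and validationSteps.

```javascript
// Hypothetical helper (not part of TensorFlow.js) describing what the fit()
// call in Example 2 does with its options. Assumption: the split point is
// Math.floor(numSamples * (1 - validationSplit)), and the validation rows
// are the last rows of the data, chosen before any shuffling.
function describeFitPlan({ numSamples, batchSize, epochs, initialEpoch = 0,
                          validationSplit = 0 }) {
  const splitAt = Math.floor(numSamples * (1 - validationSplit));
  const trainRows = splitAt;                    // rows 0 .. splitAt - 1
  const validationRows = numSamples - splitAt;  // the last rows
  // Epoch indices that are actually run: from initialEpoch up to,
  // but not including, the index given by epochs.
  const epochIndices = [];
  for (let e = initialEpoch; e < epochs; e++) epochIndices.push(e);
  // Number of mini-batches needed to cover the training rows once.
  const batchesPerEpoch = Math.ceil(trainRows / batchSize);
  return { trainRows, validationRows, epochIndices, batchesPerEpoch };
}

console.log(describeFitPlan({ numSamples: 6, batchSize: 5, epochs: 4,
                              initialEpoch: 2, validationSplit: 0.2 }));
```

Under these assumptions, 4 rows are used for training and 2 for validation, and only epochs 2 and 3 run, which is why both loss and val_loss contain two entries.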
Reference: https://js.tensorflow.org/api/latest/#tf.LayersModel.fit

