
Tensorflow.js tf.Sequential class .fitDataset() Method

Tensorflow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in a Node.js environment. The .fitDataset() method of the Tensorflow.js tf.Sequential class is used to train the model using a dataset object.

Syntax:



model.fitDataset(dataset, args);

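Parameters: This method accepts the following parameters:

dataset: The tf.data.Dataset to train on. Each element it yields is one batch, given as an object {xs, ys} holding the input features and the target values.
args: An object of training options. The epochs field (the number of epochs to train for) is required. Optional fields include batchesPerEpoch, validationData, callbacks and verbose.

Returns: A Promise that resolves to a History object. Its history.loss array holds one loss value per training epoch.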
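Because history.loss stores one value per epoch, index 0 is the first epoch and the last index is the final epoch. The sketch below reads those values from a History-like object. summarizeLoss is a hypothetical helper written for illustration, not part of the Tensorflow.js API, and it assumes only that history.history.loss is a non-empty array of numbers.

```javascript
// Hypothetical helper (not a Tensorflow.js API): summarizes the per-epoch
// losses found in history.history.loss, one number per epoch.
function summarizeLoss(history) {
  const losses = history.history.loss;
  if (!Array.isArray(losses) || losses.length === 0) {
    throw new Error("history.history.loss must be a non-empty array");
  }
  const first = losses[0];                 // loss after the first epoch
  const last = losses[losses.length - 1];  // loss after the final epoch
  return {
    epochs: losses.length,
    first,
    last,
    // Fraction by which the loss fell between the first and final epoch
    relativeDrop: (first - last) / first,
  };
}

// Usage with a plain object shaped like the resolved History value
const fakeHistory = { history: { loss: [4, 2, 1] } };
console.log(summarizeLoss(fakeHistory)); // relativeDrop is 0.75 here
```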



 

Example 1: In this example, we will train our model using a dataset built from arrays.




import * as tf from "@tensorflow/tfjs"
  
// Creating model
const gfg_Model = tf.sequential();
  
// Adding layer to model
const config = {units: 1, inputShape: [2]}
const gfg_layer = tf.layers.dense(config);
gfg_Model.add(gfg_layer);
  
// Compiling the model
const config2 = {optimizer: 'sgd', loss: 'meanSquaredError'};
gfg_Model.compile(config2);
  
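// Creating Datasets for training
const array1 = [[1,2], [1,4], [1,3], [3,4]];
// Labels: tf.data.zip() stops at the shorter dataset,
// so only the first two rows of array1 are paired with a label
const array2 = [1, 1];
const arrData1 = tf.data.array(array1);
const arrData2 = tf.data.array(array2);

// Pairing features (xs) with labels (ys), then grouping them into batches of 3
const config3 = {xs: arrData1, ys: arrData2};
const arrayDataset = tf.data.zip(config3);
const ArrayDataset = arrayDataset.batch(3).shuffle(6);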
  
// Training the model (top-level await requires an ES module)
const Tm = await gfg_Model.fitDataset(ArrayDataset, { epochs: 3 });

// Printing the loss recorded for the first epoch
console.log("Loss : " + Tm.history.loss[0]);

Output:

Loss : 0.428712397813797
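To see how many training examples Example 1 actually uses, the plain-JavaScript sketch below mimics the zip and batch steps on the same arrays. It is an illustration, not the tf.data implementation. It assumes that, like tf.data.zip(), pairing stops at the shorter input, and that, like .batch(n), a final batch may hold fewer than n elements. zipShortest and toBatches are hypothetical names.

```javascript
// Illustration only (not tf.data code): pairs features with labels,
// stopping at the shorter array, as tf.data.zip() does.
function zipShortest(xs, ys) {
  const n = Math.min(xs.length, ys.length);
  const pairs = [];
  for (let i = 0; i < n; i++) {
    pairs.push({ xs: xs[i], ys: ys[i] });
  }
  return pairs;
}

// Illustration only: groups items into batches of batchSize,
// keeping a smaller final batch, as .batch(n) does by default.
function toBatches(items, batchSize) {
  const batches = [];
  for (let i = 0; i < items.length; i += batchSize) {
    batches.push(items.slice(i, i + batchSize));
  }
  return batches;
}

// Same data as Example 1
const array1 = [[1, 2], [1, 4], [1, 3], [3, 4]];
const array2 = [1, 1];

const pairs = zipShortest(array1, array2);
const batches = toBatches(pairs, 3);
console.log(pairs.length);   // 2: only two labels are available
console.log(batches.length); // 1: a single batch holding both pairs
```

So each of the 3 epochs in Example 1 trains on one batch of two examples. Giving array2 one label per row of array1 would let all four rows take part.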

Example 2: In this example, we will train our model on a dataset created from a CSV file.




import * as tf from "@tensorflow/tfjs";
  
// Path for the CSV file
const gfg_CsvFile =
  
// Creating model
const gfg_Model = tf.sequential();
  
// Adding layer to model
const config = { units: 1, inputShape: [12] };
const gfg_layer = tf.layers.dense(config);
gfg_Model.add(gfg_layer);
  
// Compiling the model
const opt = tf.train.sgd(0.0001);
gfg_Model.compile({ optimizer: opt, loss: "meanSquaredError" });
  
// Marking the "tax" column as the label to predict
const config2 = { columnConfigs: { tax: { isLabel: true } } };
const csvDataset = tf.data.csv(gfg_CsvFile, config2);
  
// Converting each row's feature and label objects to arrays, then batching
const flattenedDataset = csvDataset
  .map(({ xs, ys }) => {
    return { xs: Object.values(xs), ys: Object.values(ys) };
  })
  .batch(5);
  
// Training the model
const Tm = await gfg_Model.fitDataset(flattenedDataset, { epochs: 5 });
  
// Printing the loss recorded for each epoch
for (let i = 0; i < Tm.history.loss.length; i++) {
  console.log(Tm.history.loss[i]);
}

Output:

21489.68359375
8750.29296875
6632.365234375
5908.6171875
5546.45654296875
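The .map() step in Example 2 is what turns each CSV row into arrays that the dense layer can consume. The plain-JavaScript sketch below applies the same callback to a single row. It assumes each element has the shape { xs: {column: value, ...}, ys: {labelColumn: value} }, which is how the example's callback destructures it. The feature column names and values are hypothetical, chosen only for illustration.

```javascript
// Same logic as the .map() callback in Example 2: replace the
// column-keyed objects with arrays of their values.
function flattenRow({ xs, ys }) {
  // Object.values keeps the insertion order of non-numeric keys,
  // so the features stay in the same column order for every row.
  return { xs: Object.values(xs), ys: Object.values(ys) };
}

// Hypothetical row with two feature columns and the "tax" label
const sampleRow = { xs: { featureA: 0.5, featureB: 12 }, ys: { tax: 300 } };
console.log(flattenRow(sampleRow)); // { xs: [ 0.5, 12 ], ys: [ 300 ] }
```

In Example 2 each xs array therefore has one entry per non-label column, which is why the layer is declared with inputShape: [12].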

Reference: https://js.tensorflow.org/api/latest/#tf.Sequential.fitDataset

