TensorFlow.js tf.Sequential class .fitDataset() Method

Last Updated: 01 Sep, 2021

TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in Node.js. The .fitDataset() method of the tf.Sequential class trains the model using a dataset object. Unlike fit(), which takes tensors that are already in memory, fitDataset() pulls batches from a tf.data.Dataset as training proceeds.


model.fitDataset(dataset, args);

Parameters: This method contains the following parameters:

  • dataset: The dataset to train on, typically a tf.data.Dataset. It must already be batched: each element it yields is one batch, given either as an object { xs, ys } or as a two-item array [xs, ys], where xs holds the input features and ys the target values.
  • args: An object that can contain the following fields:
    • epochs: The number of times to iterate over the training dataset. It is an integer value.
    • batchesPerEpoch: The number of batches that make up one epoch. It is typically the number of samples divided by the batch size, so the larger the batch size, the fewer batches per epoch. If it is omitted, an epoch ends when the dataset iterator is exhausted.
    • verbose: Controls how progress is reported for each epoch. 0 means no messages are printed during the call. 1, the default, prints a progress bar in Node.js and does nothing in the browser. 2 is not implemented yet.
    • callbacks: A list of callbacks to be called during training. It can define one or more of onTrainBegin(), onTrainEnd(), onEpochBegin(), onEpochEnd(), onBatchBegin(), onBatchEnd(), and onYield().
    • validationData: Data on which the loss and any metrics are evaluated at the end of each epoch; the model is not trained on it. It can be an array [xVal, yVal] or a Dataset object whose elements have the form { xs: xVal, ys: yVal }.
    • validationBatchSize: The batch size used for validation when validationData is an array of tensors rather than a Dataset. Its default value is 32.
    • validationBatches: Used only when validationData is a Dataset. It is the number of batches drawn from validationData for validation at the end of every epoch; if it is omitted, validation runs until that dataset is exhausted.
    • classWeight: An object mapping class indices to weights that scale the loss for samples of each class. It can be useful to tell the model to pay more attention to samples from an under-represented class.
    • initialEpoch: The epoch at which to start training, which is useful for resuming a previous training run. When it is set, epochs is treated as the index of the final epoch rather than as a number of additional epochs.
    • yieldEvery: How often training yields the main thread to other tasks. It can be 'auto' (the default; yielding happens at a certain frame rate), 'batch' (yield after every batch), 'epoch' (yield after every epoch), a number (yield every that many milliseconds), or 'never' (never yield).
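The batchesPerEpoch description says the number of batches in an epoch shrinks as the batch size grows. The plain JavaScript sketch below (no TensorFlow.js required) computes that count, assuming a smaller final batch is kept when the sample count is not a multiple of the batch size, as batch() does by default. The helper name batchesPerEpoch is hypothetical and simply mirrors the option name.

```javascript
// Hypothetical helper: how many batches one full pass over the data yields.
// Assumes the final, possibly smaller, batch is kept rather than dropped,
// so a partial batch still counts as one batch.
function batchesPerEpoch(numSamples, batchSize) {
  if (!Number.isInteger(numSamples) || numSamples < 0) {
    throw new RangeError("numSamples must be a non-negative integer");
  }
  if (!Number.isInteger(batchSize) || batchSize <= 0) {
    throw new RangeError("batchSize must be a positive integer");
  }
  return Math.ceil(numSamples / batchSize);
}

// With 1000 samples, larger batches mean fewer batches per epoch.
// Prints 1000, 125, 32 and 8 batches respectively.
for (const size of [1, 8, 32, 128]) {
  console.log(`batch size ${size}: ${batchesPerEpoch(1000, size)} batches`);
}
```

Passing this value as batchesPerEpoch makes each epoch cover the whole dataset once; when the option is left out, fitDataset() instead ends an epoch when the dataset iterator reports that it is done.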

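Returns: Promise<History>. Once training finishes, the promise resolves to a History object whose history field holds the per-epoch values of the loss and of any metrics.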
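The resolved History object keeps one loss value per epoch in history.loss, with extra keys such as val_loss when validation data is used. The sketch below summarizes such an object; the helper summarizeLoss and the hand-written fakeHistory literal are hypothetical stand-ins for the value fitDataset() resolves to, used so the snippet runs without TensorFlow.js.

```javascript
// Hypothetical helper: summarize the per-epoch losses stored in a
// History-like object of the form { epoch: [...], history: { loss: [...] } }.
function summarizeLoss(trainingHistory) {
  const losses = trainingHistory.history.loss;
  if (!Array.isArray(losses) || losses.length === 0) {
    throw new Error("history.loss must be a non-empty array");
  }
  return {
    epochs: losses.length,
    first: losses[0],                // loss recorded for the first epoch
    last: losses[losses.length - 1], // loss recorded for the final epoch
    best: Math.min(...losses),       // lowest loss seen in any epoch
  };
}

// Hand-written stand-in for the object that fitDataset() resolves to.
const fakeHistory = { epoch: [0, 1, 2], history: { loss: [0.9, 0.5, 0.6] } };
console.log(summarizeLoss(fakeHistory));
```

Note that history.loss[0] is the loss for the first epoch; the loss for the final epoch is the last element of the array.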


Example 1: In this example, we will train our model using a dataset built from arrays.


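import * as tf from "@tensorflow/tfjs";
// Creating model
const gfg_Model = tf.sequential();
// Adding a dense layer that maps 2 input features to 1 output
const config = { units: 1, inputShape: [2] };
const gfg_layer = tf.layers.dense(config);
gfg_Model.add(gfg_layer);
// Compiling the model
const config2 = { optimizer: "sgd", loss: "meanSquaredError" };
gfg_Model.compile(config2);
// Creating Datasets for training: 4 input rows and 4 matching labels,
// each label shaped [1] to match the layer's output
const array1 = [[1, 2], [1, 4], [1, 3], [3, 4]];
const array2 = [[1], [1], [1], [1]];
const arrData1 = tf.data.array(array1);
const arrData2 = tf.data.array(array2);
// Zipping into elements of the form { xs, ys }, as fitDataset() expects
const config3 = { xs: arrData1, ys: arrData2 };
const arrayDataset = tf.data.zip(config3);
// Shuffling individual samples, then grouping them into batches of 3
const batchedDataset = arrayDataset.shuffle(4).batch(3);
// Training the model (top-level await requires an ES module context)
const Tm = await gfg_Model.fitDataset(batchedDataset, { epochs: 3 });
// Printing the loss recorded for the final epoch
const losses = Tm.history.loss;
console.log("Loss : " + losses[losses.length - 1]);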


Output (the value varies from run to run because the layer's weights are initialized randomly):

Loss : 0.428712397813797

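Example 1 builds its batches with batch() and shuffle(), and the order of these two calls matters: shuffling after batching only reorders whole batches, so the same samples always travel together, while shuffling before batching mixes individual samples into new batches. The plain JavaScript sketch below imitates both orders on ordinary arrays; batchArray and shuffleArray are hypothetical helpers written for this illustration, not TensorFlow.js APIs.

```javascript
// Hypothetical helper: split an array into consecutive chunks of `size`,
// keeping a smaller final chunk when the length is not a multiple of size.
function batchArray(items, size) {
  const batches = [];
  for (let i = 0; i < items.length; i += size) {
    batches.push(items.slice(i, i + size));
  }
  return batches;
}

// Hypothetical helper: Fisher-Yates shuffle that returns a new array and
// leaves the input untouched.
function shuffleArray(items, random = Math.random) {
  const out = items.slice();
  for (let i = out.length - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1));
    [out[i], out[j]] = [out[j], out[i]];
  }
  return out;
}

const samples = [1, 2, 3, 4, 5, 6];

// batch then shuffle: only the order of whole batches changes,
// so [1, 2, 3] and [4, 5, 6] always stay together.
console.log(shuffleArray(batchArray(samples, 3)));

// shuffle then batch: samples are mixed first, so each batch can
// contain any combination of the six samples.
console.log(batchArray(shuffleArray(samples), 3));
```

For training, shuffling individual samples (calling shuffle() before batch()) is usually what you want; a shuffle buffer at least as large as the dataset gives a full shuffle.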
Example 2: In this example, we will train our model with a dataset created from a CSV file.


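import * as tf from "@tensorflow/tfjs";
// Path for the CSV file
const gfg_CsvFile = "..."; // path or URL of the CSV file to load
// Creating model
const gfg_Model = tf.sequential();
// Adding layer to model: 12 feature columns in, 1 predicted value out
const config = { units: 1, inputShape: [12] };
const gfg_layer = tf.layers.dense(config);
gfg_Model.add(gfg_layer);
// Compiling the model with a small learning rate
const opt = tf.train.sgd(0.0001);
gfg_Model.compile({ optimizer: opt, loss: "meanSquaredError" });
// Here we want to predict column tax, so it is marked as the label
const config2 = { columnConfigs: { tax: { isLabel: true } } };
const csvDataset = tf.data.csv(gfg_CsvFile, config2);
// Creating dataset for training: convert each row's feature and label
// objects into plain arrays, then group rows into batches
const flattenedDataset = csvDataset
  .map(({ xs, ys }) => {
    return { xs: Object.values(xs), ys: Object.values(ys) };
  })
  .batch(4); // illustrative batch size
// Training the model (top-level await requires an ES module context)
const Tm = await gfg_Model.fitDataset(flattenedDataset, { epochs: 5 });
// Printing the loss recorded for each of the 5 epochs
for (let i = 0; i < 5; i++) {
  console.log("Loss " + i + " : " + Tm.history.loss[i]);
}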




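In Example 2, tf.data.csv() yields each row as an object of the form { xs: { column: value, ... }, ys: { tax: value } }, and the map() step turns both parts into plain arrays with Object.values() so they can be batched into tensors. The sketch below applies the same transformation to a hand-written row; flattenRow is a hypothetical name, and the column names and values are made up for illustration.

```javascript
// Same transformation as the map() callback in Example 2, applied to a
// hand-written row. Column names and values below are hypothetical.
function flattenRow({ xs, ys }) {
  // Object.values() returns values in key insertion order (for keys that
  // are not integer-like), so the feature order follows the column order
  // in which the keys were added to the object.
  return { xs: Object.values(xs), ys: Object.values(ys) };
}

const exampleRow = {
  xs: { featureA: 1.5, featureB: 2, featureC: 3.25 }, // feature columns
  ys: { tax: 10 },                                    // label column
};

// Logs an object with xs = [1.5, 2, 3.25] and ys = [10].
console.log(flattenRow(exampleRow));
```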