TensorFlow.js tf.data.Dataset class .batch() Method

Last Updated: 31 May, 2021

TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in Node.js. It also lets developers build ML models in JavaScript and use them directly in the browser or in Node.js.

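The tf.data.Dataset.batch() method groups consecutive elements of a dataset into batches of a fixed size. All elements are assumed to have the same structure: primitive values are grouped into a 1-D tensor, tensors (and arrays, which are first converted to tensors) are stacked along a new first axis that acts as the batch dimension, and for object elements each key is batched separately.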
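When elements are objects, for example feature/label pairs, the TensorFlow.js API reference describes batch() as collecting all incoming values for each key separately. The plain-JavaScript sketch below models only that per-key regrouping for a single batch. `collateBatch` is a hypothetical helper written for illustration, not part of TensorFlow.js, and it returns nested arrays where the real method would return tensors.

```javascript
// Hypothetical helper (NOT a TensorFlow.js API): models how one batch of
// same-shaped object elements is regrouped key by key.
// The real Dataset.batch() produces a tensor per key instead of an array.
function collateBatch(elements) {
  if (elements.length === 0) {
    throw new Error("A batch must contain at least one element");
  }
  const first = elements[0];
  // Primitives and arrays are treated as leaves: they are collected into one
  // array, mirroring how batch() stacks them along a new batch axis.
  if (typeof first !== "object" || first === null || Array.isArray(first)) {
    return elements.slice();
  }
  // For objects, gather the values of each key across all elements.
  const result = {};
  for (const key of Object.keys(first)) {
    result[key] = collateBatch(elements.map((element) => element[key]));
  }
  return result;
}

const exampleBatch = collateBatch([
  { xs: [1, 2], ys: 0 },
  { xs: [3, 4], ys: 1 },
  { xs: [5, 6], ys: 0 },
]);
console.log(JSON.stringify(exampleBatch));
// {"xs":[[1,2],[3,4],[5,6]],"ys":[0,1,0]}
```

Recursing on each key keeps the sketch correct for nested objects as well, since every level is regrouped the same way.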

Syntax:

tf.data.Dataset.batch(batchSize, smallLastBatch?)

Parameters:

  • batchSize: the number of elements to place in each batch.
  • smallLastBatch: if true, the final batch is emitted even when it contains fewer than batchSize elements; if false, an incomplete final batch is dropped. This parameter is optional and defaults to true.
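To make the two parameters concrete, the following plain-JavaScript sketch mimics which elements end up in which batch for an ordinary array. `batchArray` is a hypothetical helper written for illustration only; it is not part of TensorFlow.js, and unlike the real method it does not stack each group into a tensor.

```javascript
// Hypothetical helper (NOT a TensorFlow.js API): models how batchSize and
// smallLastBatch decide the grouping. The real method also stacks each
// group into a tensor, which this sketch skips.
function batchArray(elements, batchSize, smallLastBatch = true) {
  if (!Number.isInteger(batchSize) || batchSize < 1) {
    throw new Error("batchSize must be a positive integer");
  }
  const batches = [];
  for (let start = 0; start < elements.length; start += batchSize) {
    const group = elements.slice(start, start + batchSize);
    // An incomplete final group is kept only when smallLastBatch is true.
    if (group.length < batchSize && !smallLastBatch) {
      break;
    }
    batches.push(group);
  }
  return batches;
}

const data = [10, 20, 30, 40, 50, 60, 70, 80];
console.log(JSON.stringify(batchArray(data, 3)));        // [[10,20,30],[40,50,60],[70,80]]
console.log(JSON.stringify(batchArray(data, 3, false))); // [[10,20,30],[40,50,60]]
```

Only the last group can ever be shorter than batchSize, so checking the length inside the loop is enough to implement the drop behaviour.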

Return value: It returns a new tf.data.Dataset whose elements are the batches.



Example 1: In this example we take an array of size 6 and split it into batches of 3 elements each.

Javascript




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a dataset from an array and grouping it into batches of 3
const gfg = tf.data.array(
  [10, 20, 30, 40, 50, 60]
).batch(3);

// Printing each batch (top-level await requires an ES module context)
await gfg.forEachAsync(
  element => element.print()
);

Output:

Tensor
    [10, 20, 30]
Tensor
    [40, 50, 60]

Example 2: This time we take 8 elements and split them into batches of 3 elements each. Because 8 is not a multiple of 3, the last batch is incomplete.

Javascript




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a dataset from an array and grouping it into batches of 3
const gfg = tf.data.array(
  [10, 20, 30, 40, 50, 60, 70, 80]
).batch(3);

// Printing each batch (top-level await requires an ES module context)
await gfg.forEachAsync(
  element => element.print()
);

Output:

Tensor
    [10, 20, 30]
Tensor
    [40, 50, 60]
Tensor
    [70, 80]

Because smallLastBatch defaults to true, the incomplete third batch, which holds only 2 elements, is still emitted.

Example 3: This time we pass false as the smallLastBatch argument.

Javascript




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
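// Creating a dataset from an array and grouping it into batches of 3,
// dropping the final batch if it has fewer than 3 elements
const gfg = tf.data.array(
  [10, 20, 30, 40, 50, 60, 70, 80]
).batch(3, false);

// Printing each batch (top-level await requires an ES module context)
await gfg.forEachAsync(
  element => element.print()
);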

Output:

Tensor
    [10, 20, 30]
Tensor
    [40, 50, 60]

Because smallLastBatch is set to false, the third batch is dropped: it would contain only 2 elements, fewer than the specified batch size of 3.
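The batch counts in Examples 2 and 3 follow from simple arithmetic: with n elements and batch size b, there are floor(n / b) full batches and n mod b leftover elements, which form one extra, smaller batch when smallLastBatch is true and are discarded when it is false. The sketch below computes this; `describeBatches` is a hypothetical helper for illustration only, not part of TensorFlow.js.

```javascript
// Hypothetical helper (NOT a TensorFlow.js API): computes how many batches
// batch(batchSize, smallLastBatch) yields for numElements inputs.
function describeBatches(numElements, batchSize, smallLastBatch = true) {
  const fullBatches = Math.floor(numElements / batchSize);
  const leftover = numElements % batchSize;
  // Leftover elements form an extra batch only when smallLastBatch is true.
  const keepLeftover = leftover > 0 && smallLastBatch;
  return {
    batchCount: fullBatches + (keepLeftover ? 1 : 0),
    lastBatchSize: keepLeftover ? leftover : (fullBatches > 0 ? batchSize : 0),
    droppedElements: keepLeftover ? 0 : leftover,
  };
}

console.log(describeBatches(8, 3));        // { batchCount: 3, lastBatchSize: 2, droppedElements: 0 }
console.log(describeBatches(8, 3, false)); // { batchCount: 2, lastBatchSize: 3, droppedElements: 2 }
```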

Reference: https://js.tensorflow.org/api/latest/#tf.data.Dataset.batch



