
TensorFlow.js tf.data.Dataset class .batch() Method

TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in a Node.js environment. It also lets developers build ML models in JavaScript and run them directly in the browser or in Node.js.

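The tf.data.Dataset.batch() method groups consecutive elements of a dataset into batches and returns a new dataset whose elements are those batches. Incoming primitive values, such as the numbers used in the examples below, are grouped into a 1-D tensor. Incoming tensors are stacked into a new tensor whose first (0th) axis is the batch dimension.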
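The grouping can be illustrated with a minimal plain-JavaScript sketch. It works over an ordinary iterable and yields plain arrays. It does not yield tensors, and it is not how TensorFlow.js implements batch(). The helper name `batchElements` is hypothetical. The `smallLastBatch` flag mirrors the optional TensorFlow.js parameter of the same name, which defaults to true and controls whether a shorter final batch is emitted.

```javascript
// Illustrative sketch only: NOT the TensorFlow.js implementation.
// batchElements is a hypothetical helper that groups consecutive elements
// of any iterable into arrays of batchSize elements.
function* batchElements(iterable, batchSize, smallLastBatch = true) {
  // Guard added by this sketch; not taken from the TensorFlow.js API.
  if (!Number.isInteger(batchSize) || batchSize < 1) {
    throw new RangeError("batchSize must be a positive integer");
  }
  let current = [];
  for (const element of iterable) {
    current.push(element);
    // Emit a full batch as soon as it reaches batchSize elements.
    if (current.length === batchSize) {
      yield current;
      current = [];
    }
  }
  // Leftover elements form a smaller final batch only when smallLastBatch is true.
  if (current.length > 0 && smallLastBatch) {
    yield current;
  }
}

// Usage mirroring the examples in this article:
console.log(JSON.stringify(Array.from(batchElements([10, 20, 30, 40, 50, 60, 70, 80], 3))));
// [[10,20,30],[40,50,60],[70,80]]
console.log(JSON.stringify(Array.from(batchElements([10, 20, 30, 40, 50, 60, 70, 80], 3, false))));
// [[10,20,30],[40,50,60]]
```

A generator is used because tf.data datasets are produced lazily, one element at a time, rather than as a fully materialized array.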



Syntax:

tf.data.Dataset.batch(batchSize, smallLastBatch?)

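Parameters:

batchSize: The number of elements desired per batch.

smallLastBatch: Optional boolean. Whether to emit the final batch when it contains fewer than batchSize elements. Defaults to true.
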
Return value: It returns a tf.data.Dataset.

Example 1: In this example we will take an array of size 6 and split it into batches of 3 elements each.




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a dataset from an array and grouping it into batches of 3
const gfg = tf.data.array(
  [10, 20, 30, 40, 50, 60]
).batch(3);
  
// Printing the elements
await gfg.forEachAsync(
  element => element.print()
);

Output:

"Tensor
    [10, 20, 30]"
"Tensor
    [40, 50, 60]"

Example 2: This time we will take 8 elements and split them into batches of 3 elements each.




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a dataset from an array and grouping it into batches of 3
const gfg = tf.data.array(
  [10, 20, 30, 40, 50, 60, 70, 80]
).batch(3);
  
// Printing the elements
await gfg.forEachAsync(
  element => element.print()
);

Output:

"Tensor
    [10, 20, 30]"
"Tensor
    [40, 50, 60]"
"Tensor
    [70, 80]"

Because smallLastBatch defaults to true, a 3rd batch is emitted even though it contains only 2 elements.

Example 3: This time we will pass false as the smallLastBatch argument.




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a dataset from an array and grouping it into batches of 3, dropping a smaller final batch
const gfg = tf.data.array(
  [10, 20, 30, 40, 50, 60, 70, 80]
).batch(3, false);
  
// Printing the elements
await gfg.forEachAsync(
  element => element.print()
);

Output:

"Tensor
    [10, 20, 30]"
"Tensor
    [40, 50, 60]"

Because smallLastBatch is set to false, the 3rd batch is dropped. It would contain only 2 elements, which is fewer than the specified batch size of 3.
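The three examples can be summarized arithmetically. With n elements and a batch size of b, batch() yields ceil(n / b) batches when smallLastBatch is true and floor(n / b) batches when it is false. The sketch below is plain JavaScript, not part of the TensorFlow.js API, and the name `expectedBatchCount` is hypothetical.

```javascript
// Illustrative sketch only: expectedBatchCount is a hypothetical helper,
// not a TensorFlow.js function. It counts the batches batch() would yield.
function expectedBatchCount(numElements, batchSize, smallLastBatch = true) {
  const fullBatches = Math.floor(numElements / batchSize);
  const leftover = numElements % batchSize;
  // A partial final batch exists only if elements are left over,
  // and it is emitted only when smallLastBatch is true.
  return fullBatches + (leftover > 0 && smallLastBatch ? 1 : 0);
}

console.log(expectedBatchCount(6, 3));        // 2 (Example 1)
console.log(expectedBatchCount(8, 3));        // 3 (Example 2)
console.log(expectedBatchCount(8, 3, false)); // 2 (Example 3)
```

When smallLastBatch is false, the number of discarded elements is numElements % batchSize. In Example 3 that is 8 % 3 = 2, namely the values 70 and 80.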

Reference: https://js.tensorflow.org/api/latest/#tf.data.Dataset.batch
