
Tensorflow.js tf.data.Dataset Class

Tensorflow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or node environment.

The Tensorflow.js tf.data.Dataset class represents a potentially large, ordered collection of data elements. Each element can be a primitive, an array, a map, or any nested structure of these types. A dataset provides methods such as batch() and shuffle() that transform its elements. Each of these methods returns a new dataset rather than modifying the original one.



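Syntax: 

dataset.method(args);

Here, dataset is a tf.data.Dataset object, for example one created with tf.data.array().

Parameters: The arguments depend on the method being called. For example, batch() takes batchSize and smallLastBatch, while shuffle() takes bufferSize, seed, and reshuffleEachIteration.

Return Value: Transformation methods such as batch() and shuffle() return a new tf.data.Dataset.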

Example 1: In this example, we use the batch() method, which groups consecutive elements into batches. It accepts two arguments: batchSize, the number of elements in each batch, and smallLastBatch, an optional boolean (true by default) that controls whether a final batch with fewer than batchSize elements is emitted.




// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs")
 
// Creating tensor dataset of array
const array = [[1, 4, 6], [2, 5, 7],
    [3, 6, 8], [4, 7, 9], [5, 8, 11]];
 
// Making dataset with array
const gfg = tf.data.array(array);
 
// Creating a new dataset that groups the elements into batches of 2
const GFG = gfg.batch(2);
 
// Printing Dataset
GFG.forEachAsync(Q => Q.print());

Output:

Tensor
    [[1, 4, 6],
     [2, 5, 7]]
Tensor
    [[3, 6, 8],
     [4, 7, 9]]
Tensor
     [[5, 8, 11],]

Example 2: In this example, we use the shuffle() method, which pseudorandomly shuffles the elements of the dataset. It takes three arguments: bufferSize, the number of elements held in a buffer from which each output element is sampled at random; seed, an optional string used to seed the random number generator; and reshuffleEachIteration, an optional boolean (true by default) that indicates whether the dataset should be shuffled in a different order each time it is iterated over.




// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs")
 
// Creating Dataset  
const gfg_array = [4, 8, 12, 16, 20];
const gfg = tf.data.array(gfg_array);
 
// Shuffling the dataset using a buffer of 5 elements
const GFG_array = gfg.shuffle(5);
 
// Printing each element of the shuffled dataset
GFG_array.forEachAsync(tm => console.log(tm));

Output (the order may differ on each run because no seed is passed):

8
20
4
16
12

Reference: https://js.tensorflow.org/api/latest/#class:data.Dataset

