Tensorflow.js tf.data.Dataset Class

Last Updated: 30 Nov, 2021

TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in Node.js.

The tf.data.Dataset class in TensorFlow.js represents a potentially large, ordered collection of data elements. Each element can be an array, a map (a plain JavaScript object), a primitive, or any nested structure of these types. Elements are produced lazily, when the dataset is iterated (for example with forEachAsync()). Dataset methods such as batch() and shuffle() do not modify the dataset they are called on; each one returns a new dataset.

Syntax: 

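const dataset = tf.data.array(items);
const newDataset = dataset.method(args);

A dataset is first created with a factory function under tf.data (such as tf.data.array()), and the Dataset methods are then called on that instance.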

Parameters: The arguments depend on the method being called:

  • args: The method-specific arguments, for example batchSize and smallLastBatch for batch(), or bufferSize, seed and reshuffleEachIteration for shuffle().

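Return Value: It returns a new tf.data.Dataset; the dataset the method was called on is left unchanged. The structure of the new elements depends on the method (batch(), for example, combines several elements into one batched tensor).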

Example 1: In this example, we use the batch() method, which groups consecutive elements into batches. It takes two arguments: batchSize, the number of elements in each batch, and smallLastBatch, a boolean (default true) that controls whether the final batch is emitted when it contains fewer than batchSize elements.

Javascript




// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs");
 
// Creating a test dataset from a 2D array
const array = [[1, 4, 6], [2, 5, 7],
    [3, 6, 8], [4, 7, 9], [5, 8, 11]];
 
// Making a dataset from the array; each row is one element
const gfg = tf.data.array(array);
 
// Creating a new dataset that groups rows into batches of 2
const GFG = gfg.batch(2);
 
// Printing each batch (forEachAsync returns a Promise)
GFG.forEachAsync(Q => Q.print());

Output:

Tensor
    [[1, 4, 6],
     [2, 5, 7]]
Tensor
    [[3, 6, 8],
     [4, 7, 9]]
Tensor
     [[5, 8, 11],]

Example 2: In this example, we use the shuffle() method, which randomly reorders the elements of the dataset. It takes three arguments: bufferSize, the number of elements from which the new dataset samples (a buffer at least as large as the dataset gives a full shuffle); seed, an optional string used to seed the random number generator; and reshuffleEachIteration, an optional boolean (default true) that controls whether the dataset is reshuffled every time it is iterated over.

Javascript




// Importing the tensorflow.js library
const tf = require("@tensorflow/tfjs");
 
// Creating a dataset from an array of numbers
const gfg_array = [4, 8, 12, 16, 20];
const gfg = tf.data.array(gfg_array);
 
// Shuffling with a buffer as large as the dataset
const GFG_array = gfg.shuffle(5);
 
// Printing each element of the shuffled dataset
GFG_array.forEachAsync(tm => console.log(tm));

Output (the order is random and differs between runs):

8
20
4
16
12

Reference: https://js.tensorflow.org/api/latest/#class:data.Dataset

