
TensorFlow.js tf.layers.lstmCell() Function

TensorFlow.js is an open-source library developed by Google for running machine learning models, including deep neural networks, in the browser or in a Node.js environment.

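The tf.layers.lstmCell() function creates an LSTMCell, the cell class for LSTM (long short-term memory). The cell is distinct from the RNN subclass LSTM (created with tf.layers.lstm()): the cell's apply() method takes the input of only a single time step and returns the cell's output for that step, whereas the LSTM layer takes input over a number of time steps. To run a cell over a whole sequence, wrap it in tf.layers.rnn().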
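To make "a single time step" concrete, here is a minimal, dependency-free JavaScript sketch of the standard LSTM update for one step. It is not the TensorFlow.js implementation: the function names (lstmStep, gatePreactivation) are hypothetical, the four gates are kept as separate weight matrices for readability, and the gates use the logistic sigmoid, whereas tf.layers.lstmCell() lets you choose the gate activation through its recurrentActivation option.

```javascript
// Hypothetical, dependency-free sketch of ONE LSTM time step.
// It illustrates the math only; it is not the TensorFlow.js implementation.
const sigmoid = (z) => 1 / (1 + Math.exp(-z));

// Pre-activation of one gate: W·x + U·hPrev + b
// (W is units x inputDim, U is units x units, b has length units).
function gatePreactivation({ W, U, b }, x, hPrev) {
  return b.map((bias, j) => {
    let sum = bias;
    for (let k = 0; k < x.length; k++) sum += W[j][k] * x[k];
    for (let k = 0; k < hPrev.length; k++) sum += U[j][k] * hPrev[k];
    return sum;
  });
}

// params holds one {W, U, b} per gate: input (i), forget (f),
// candidate (g), and output (o). Returns the new hidden and cell states.
function lstmStep(params, x, hPrev, cPrev) {
  const i = gatePreactivation(params.i, x, hPrev).map(sigmoid);
  const f = gatePreactivation(params.f, x, hPrev).map(sigmoid);
  const g = gatePreactivation(params.g, x, hPrev).map(Math.tanh);
  const o = gatePreactivation(params.o, x, hPrev).map(sigmoid);
  // New cell state: keep part of the old state, add part of the candidate.
  const c = cPrev.map((cj, j) => f[j] * cj + i[j] * g[j]);
  // New hidden state (the cell's output for this step).
  const h = c.map((cj, j) => o[j] * Math.tanh(cj));
  return { h, c };
}

// Usage: 3 units, a length-2 input vector, all weights and biases zero.
const units = 3;
const inputDim = 2;
const zeros = (rows, cols) =>
  Array.from({ length: rows }, () => new Array(cols).fill(0));
const zeroGate = () => ({
  W: zeros(units, inputDim),
  U: zeros(units, units),
  b: new Array(units).fill(0),
});
const params = { i: zeroGate(), f: zeroGate(), g: zeroGate(), o: zeroGate() };

const { h, c } = lstmStep(params, [1, -1], [0, 0, 0], [1, 1, 1]);
// Every gate is sigmoid(0) = 0.5 and the candidate is tanh(0) = 0,
// so c is [0.5, 0.5, 0.5] and each entry of h is 0.5 * tanh(0.5).
console.log(c, h);
```

An RNN layer repeats this step once per time step, feeding each step's h and c back in as hPrev and cPrev.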



Syntax:

tf.layers.lstmCell(args)

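Parameters: The function takes an args object that can contain the following parameters, among others:

units: The dimensionality of the output space (and of the hidden and cell states). This parameter is required.

activation: The activation function to use. It defaults to hyperbolic tangent (tanh).

recurrentActivation: The activation function to use for the recurrent step (the input, forget, and output gates). It defaults to hard sigmoid (hardSigmoid).

useBias: Whether the cell uses a bias vector. It defaults to true.

kernelInitializer, recurrentInitializer, biasInitializer: Initializers for the input kernel, the recurrent kernel, and the bias vector.

unitForgetBias: If true, 1 is added to the bias of the forget gate at initialization.

dropout: A number between 0 and 1; the fraction of the units to drop for the linear transformation of the inputs.

recurrentDropout: A number between 0 and 1; the fraction of the units to drop for the linear transformation of the recurrent state.

The args object also accepts the general layer arguments, such as name and trainable.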

Return value: It returns LSTMCell.

 

Example 1:




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling the layers.lstmCell 
// function and printing the output
const cell = tf.layers.lstmCell({units: 3});
const input = tf.input({shape: [120]});
const output = cell.apply(input);
  
console.log(JSON.stringify(output.shape));

Output:

[null, 120]

Example 2:




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
const cells = [
   tf.layers.lstmCell({units: 6}),
   tf.layers.lstmCell({units: 10}),
];
const rnn = tf.layers.rnn(
  {cell: cells, returnSequences: true}
);
  
// Create an input with 40 time steps
// and a length-60 vector at each step
const input = tf.input({shape: [40, 60]});
const output = rnn.apply(input);
  
console.log(JSON.stringify(output.shape));

Output:

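[null, 40, 10]

The first dimension is the unknown batch size, the second equals the number of time steps in the input (40) because returnSequences is true, and the third is the number of units in the last lstmCell (10).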

Reference: https://js.tensorflow.org/api/latest/#layers.lstmCell
