import * as tf from
"@tensorflow/tfjs"
;
// ---------------------------------------------------------------------------
// Part 1: one hand-rolled LSTM step with tf.basicLSTMCell.
//
// tf.basicLSTMCell(forgetBias, lstmKernel, lstmBias, data, c, h) concatenates
// [data, h] along axis 1, multiplies by lstmKernel, adds lstmBias and splits
// the result into four equal gate slices. For those slices to line up with
// the state, the kernel must be [inputDim + units, 4 * units] and the bias a
// 1-D tensor of length 4 * units.
// ---------------------------------------------------------------------------
const data = tf.tensor2d([70, 10, 62, 55, 74, 85, 66, 9], [4, 2]);
// Cell state (c) and hidden state / previous output (h) for the step.
const state = tf.tensor2d(
  [29, 40, 79, 61, 5, 34, 78, 47, 86, 74, 46, 28],
  [4, 3],
);
const output = tf.tensor2d(
  [25, 55, 33, 85, 82, 65, 20, 75, 54, 59, 50, 3],
  [4, 3],
);

const inputDim = data.shape[1];
const units = state.shape[1];

// Arbitrary but deterministic demo weights with the shape basicLSTMCell needs.
const kernel = tf
  .linspace(-1, 1, (inputDim + units) * 4 * units)
  .reshape([inputDim + units, 4 * units]);
// Constant 2.0 gate bias, expressed as the Tensor1D the API expects.
const lstmBias = tf.fill([4 * units], 2.0);

// basicLSTMCell is a pure tensor op returning [newC, newH]; it is not a
// reusable RNN cell object and cannot be plugged into a layer.
const [newC, newH] = tf.basicLSTMCell(
  1.0,
  kernel,
  lstmBias,
  data,
  state,
  output,
);
newC.print();
newH.print();

// ---------------------------------------------------------------------------
// Part 2: a Keras-style model built around a SimpleRNN layer.
//
// `cell` is a tf.layers.rnn option that expects an RNNCell (e.g.
// tf.layers.lstmCell); tf.layers.simpleRNN builds its own SimpleRNNCell from
// `units`, so no cell is passed here.
// ---------------------------------------------------------------------------
// Shape excludes the batch dimension: 4 timesteps x 2 features per sample.
const input = tf.input({ shape: [4, 2] });
const simpleRNNLayer = tf.layers.simpleRNN({
  units: 4,
  returnSequences: true,
  returnState: true,
});
// With returnState: true, apply() yields [sequenceOutput, finalState];
// only the per-timestep sequence output feeds the model.
const [outputs] = simpleRNNLayer.apply(input);
const model = tf.model({ inputs: input, outputs });

const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [1, 4, 2]);
// No weights are set explicitly, so the layer uses its default initializers.
model.predict(x).print();