
TensorFlow.js tf.layers setWeights() Method


TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in Node.js.

The .setWeights() method sets the weights of a layer from the given tensors.

Syntax:

setWeights(weights)

Parameters:

  • weights: The list of tensors to assign, of type tf.Tensor[]. The number of tensors and the shape of each one must match the layer's existing weights; in other words, they must match the tensors returned by the getWeights() method.

Return Value: It returns void.

Example 1:

Javascript




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a model
const model = tf.sequential();
  
// Adding a layer
model.add(tf.layers.dense({units: 2, inputShape: [11]}));
  
// Calling setWeights() with a kernel of shape [11, 2] and a bias of shape [2]
model.layers[0].setWeights([tf.truncatedNormal([11, 2]), tf.zeros([2])]);
  
// Compiling the model
model.compile({loss: 'categoricalCrossentropy', optimizer: 'sgd'});
  
// Printing output using getWeights() method
model.layers[0].getWeights()[0].print();


Output:

Tensor
    [[-0.5969906, -0.1883931],
     [0.8569255 , -0.49416  ],
     [0.1157023 , 0.1150239 ],
     [-0.4052143, 1.9936075 ],
     [0.3090054 , 0.7212474 ],
     [0.4626641 , -0.7287846],
     [0.4352857 , -0.5195332],
     [0.4626429 , 0.0216295 ],
     [-0.1110666, -0.5997615],
     [-0.5083916, -0.3582681],
     [-0.2847465, 1.184485  ]]

Here, the truncatedNormal() method creates a tf.Tensor with values sampled from a truncated normal distribution, the zeros() method creates a tf.Tensor with all elements set to 0, and the getWeights() method is used to print the weights that were set with setWeights(). Because truncatedNormal() draws random values, the printed kernel will differ from run to run.

Example 2:

Javascript




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Creating a model
const model = tf.sequential();
  
// Adding layers
model.add(tf.layers.dense({units: 1, 
    inputShape: [5], batchSize: 1, dtype: 'int32'}));
model.add(tf.layers.dense({units: 2, inputShape: [6], batchSize: 5}));
model.add(tf.layers.dense({units: 3, inputShape: [7], batchSize: 8}));
model.add(tf.layers.dense({units: 4, inputShape: [8], batchSize: 12}));
  
// Calling setWeights() method
model.layers[0].setWeights([tf.ones([5, 1]), tf.zeros([1])]);
model.layers[1].setWeights([tf.ones([1, 2]), tf.zeros([2])]);
  
// Printing output using getWeights() method
model.layers[0].getWeights()[0].print();
model.layers[0].getWeights()[1].print();
model.layers[1].getWeights()[0].print();
model.layers[1].getWeights()[1].print();


Output:

Tensor
    [[1],
     [1],
     [1],
     [1],
     [1]]
Tensor
    [0]
Tensor
     [[1, 1],]
Tensor
    [0, 0]

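Here, the ones() method creates a tf.Tensor with all elements set to 1. The kernel of the second layer has shape [1, 2] because its input is the one-unit output of the first layer. In a sequential model, only the first layer's inputShape determines the model's input, so the inputShape and batchSize values passed to the later layers have no effect.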

Reference: https://js.tensorflow.org/api/latest/#tf.layers.Layer.setWeights



Last Updated : 22 Apr, 2022