TensorFlow.js tf.layers setWeights() Method

Last Updated: 22 Jul, 2021

TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in Node.js.

The .setWeights() method sets the weights of the layer it is called on, using the values of the given tensors.

Syntax:

setWeights(weights)

Parameters:

  • weights: The list of tensors to use as the layer's new weights, of type tf.Tensor[]. The number of tensors, their order and their shapes must match the layer's current weights, i.e. the result of its getWeights() method; otherwise an error is thrown. For example, a dense layer with a bias expects a kernel of shape [inputDim, units] followed by a bias of shape [units].
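The count-and-shape rule above can be illustrated without TensorFlow.js. The following plain-JavaScript sketch roughly mirrors the kind of validation setWeights() performs before accepting new weights. Note that `denseWeightShapes()` and `checkWeightShapes()` are hypothetical helpers written for illustration, not part of the tf.layers API, and the sketch assumes a dense layer that uses a bias.

```javascript
// Hypothetical helpers for illustration only; not part of TensorFlow.js.

// Expected weight shapes of a dense layer: kernel [inputDim, units],
// followed by bias [units] when the layer uses a bias.
function denseWeightShapes(inputDim, units, useBias = true) {
  const shapes = [[inputDim, units]];
  if (useBias) {
    shapes.push([units]);
  }
  return shapes;
}

// Throws if the provided shapes do not match the expected ones:
// first the number of weights is compared, then each shape.
function checkWeightShapes(expected, provided) {
  if (expected.length !== provided.length) {
    throw new Error(
      `Expected ${expected.length} weights, but got ${provided.length}`);
  }
  expected.forEach((shape, i) => {
    const other = provided[i];
    const same = shape.length === other.length &&
      shape.every((dim, j) => dim === other[j]);
    if (!same) {
      throw new Error(
        `Weight ${i}: expected shape [${shape}], got [${other}]`);
    }
  });
  return true;
}

// Shapes from Example 1: a dense layer with 11 inputs and 2 units.
const expected = denseWeightShapes(11, 2);
console.log(JSON.stringify(expected)); // [[11,2],[2]]
console.log(checkWeightShapes(expected, [[11, 2], [2]])); // true

try {
  checkWeightShapes(expected, [[11, 2]]); // the bias is missing
} catch (e) {
  console.log(e.message); // Expected 2 weights, but got 1
}
```

In real code the shapes come from the tensors themselves (each tf.Tensor has a shape property), so comparing against getWeights() is the simplest way to know what a layer expects.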

Return Value: It returns void.



Example 1:

Javascript




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

// Creating a model
const model = tf.sequential();

// Adding a dense layer with 11 inputs and 2 units, so its weights are
// a kernel of shape [11, 2] and a bias of shape [2]
model.add(tf.layers.dense({units: 2, inputShape: [11]}));

// Calling setWeights() method with a new kernel and bias
model.layers[0].setWeights([tf.truncatedNormal([11, 2]), tf.zeros([2])]);

// Compiling the model (not required by setWeights(); included for completeness)
model.compile({loss: 'categoricalCrossentropy', optimizer: 'sgd'});

// Printing the kernel using getWeights() method
model.layers[0].getWeights()[0].print();

Output (the kernel is drawn at random, so the values will differ on each run):

Tensor
    [[-0.5969906, -0.1883931],
     [0.8569255 , -0.49416  ],
     [0.1157023 , 0.1150239 ],
     [-0.4052143, 1.9936075 ],
     [0.3090054 , 0.7212474 ],
     [0.4626641 , -0.7287846],
     [0.4352857 , -0.5195332],
     [0.4626429 , 0.0216295 ],
     [-0.1110666, -0.5997615],
     [-0.5083916, -0.3582681],
     [-0.2847465, 1.184485  ]]

Here, the truncatedNormal() method creates a tf.Tensor with values sampled from a truncated normal distribution, the zeros() method creates a tf.Tensor with all elements set to 0, and the getWeights() method returns the weights that were set with setWeights() so that they can be printed.
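To make "truncated normal" concrete: the TensorFlow.js documentation for truncatedNormal() states that values more than two standard deviations from the mean are dropped and re-picked. The sketch below illustrates that idea in plain JavaScript using Box-Muller sampling plus rejection. It is not TensorFlow.js's actual implementation, and the function names `standardNormal` and `sampleTruncatedNormal` are hypothetical.

```javascript
// Hypothetical illustration of truncated-normal sampling; not the
// TensorFlow.js implementation.

// One standard-normal sample via the Box-Muller transform.
function standardNormal() {
  // 1 - Math.random() lies in (0, 1], so Math.log never receives 0.
  const u1 = 1 - Math.random();
  const u2 = Math.random();
  return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
}

// Draws `count` values from a normal distribution with the given mean
// and standard deviation, re-drawing any sample that lies more than
// 2 standard deviations from the mean.
function sampleTruncatedNormal(count, mean = 0, stddev = 1) {
  const values = [];
  while (values.length < count) {
    const z = standardNormal();
    if (Math.abs(z) <= 2) {
      values.push(mean + stddev * z);
    }
  }
  return values;
}

// 22 values, enough to fill an [11, 2] kernel as in Example 1.
const kernelValues = sampleTruncatedNormal(11 * 2);
console.log(kernelValues.length); // 22
console.log(kernelValues.every((v) => Math.abs(v) <= 2)); // true
```

This is why every value in the printed kernel above lies between -2 and 2: with the default mean of 0 and standard deviation of 1, larger magnitudes are never kept.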

Example 2:

Javascript




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

// Creating a model
const model = tf.sequential();

// Adding layers. In a sequential model only the first layer's
// inputShape defines the model input; every later dense layer takes
// its input size from the units of the layer before it
model.add(tf.layers.dense({units: 1,
    inputShape: [5], batchSize: 1, dtype: 'int32'}));
model.add(tf.layers.dense({units: 2, inputShape: [6], batchSize: 5}));
model.add(tf.layers.dense({units: 3, inputShape: [7], batchSize: 8}));
model.add(tf.layers.dense({units: 4, inputShape: [8], batchSize: 12}));

// Calling setWeights() method
// Layer 0: 5 inputs -> 1 unit, so kernel [5, 1] and bias [1]
model.layers[0].setWeights([tf.ones([5, 1]), tf.zeros([1])]);
// Layer 1: 1 input (the single unit of layer 0) -> 2 units,
// so kernel [1, 2] and bias [2]
model.layers[1].setWeights([tf.ones([1, 2]), tf.zeros([2])]);

// Printing the kernel and bias of layers 0 and 1 using getWeights() method
model.layers[0].getWeights()[0].print();
model.layers[0].getWeights()[1].print();
model.layers[1].getWeights()[0].print();
model.layers[1].getWeights()[1].print();

Output:

Tensor
    [[1],
     [1],
     [1],
     [1],
     [1]]
Tensor
    [0]
Tensor
     [[1, 1],]
Tensor
    [0, 0]

Here, the ones() method creates a tf.Tensor with all elements set to 1. Only layers 0 and 1 are updated; layers 2 and 3 keep the weights they were initialized with.
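Why does layer 1 take a kernel of shape [1, 2] even though it was declared with inputShape: [6]? Its input is the single unit produced by layer 0, and each later dense layer's input size is the units of the layer before it. The plain-JavaScript sketch below computes the shapes setWeights() would expect for every layer of Example 2. The helper `sequentialDenseShapes` is hypothetical, not a TensorFlow.js API, and it assumes every layer is a dense layer that uses a bias.

```javascript
// Hypothetical helper for illustration; not part of TensorFlow.js.
// Assumes a stack of dense layers that all use a bias.
function sequentialDenseShapes(inputDim, unitsPerLayer) {
  const result = [];
  let currentDim = inputDim;
  for (const units of unitsPerLayer) {
    // kernel [inputs, units] followed by bias [units]
    result.push([[currentDim, units], [units]]);
    // the next layer receives this layer's outputs as its inputs
    currentDim = units;
  }
  return result;
}

// Example 2: model input of size 5, then layers with 1, 2, 3 and 4 units.
const shapes = sequentialDenseShapes(5, [1, 2, 3, 4]);
shapes.forEach((layerShapes, i) => {
  console.log(`layer ${i}: ${JSON.stringify(layerShapes)}`);
});
// layer 0: [[5,1],[1]]
// layer 1: [[1,2],[2]]
// layer 2: [[2,3],[3]]
// layer 3: [[3,4],[4]]
```

The first two entries match the tensors passed to setWeights() in Example 2; to update layers 2 and 3 as well, the kernels would need shapes [2, 3] and [3, 4].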

Reference: https://js.tensorflow.org/api/latest/#tf.layers.Layer.setWeights



