
Tensorflow.js tf.truncatedNormal() Function

TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in a Node.js environment.

The tf.truncatedNormal() function creates a tf.Tensor whose values are sampled from a truncated normal distribution. The generated values follow a normal distribution with the specified mean and standard deviation, except that any value lying more than 2 standard deviations from the mean is discarded and drawn again.
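The re-drawing rule described above can be sketched in plain JavaScript, without the library. This is only an illustration of the idea, not TensorFlow.js's internal implementation; the helper names sampleStandardNormal and sampleTruncatedNormal are hypothetical, and the Box-Muller transform is simply one convenient way to produce normal values.

```javascript
// Illustrative sketch only: NOT TensorFlow.js's internal implementation.
// sampleStandardNormal and sampleTruncatedNormal are hypothetical helper names.

// Draw one standard normal value using the Box-Muller transform.
function sampleStandardNormal() {
  // 1 - Math.random() lies in (0, 1], which keeps Math.log away from zero.
  const u1 = 1 - Math.random();
  const u2 = Math.random();
  return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
}

// Draw one value from a normal distribution, drawing again whenever the
// candidate lies more than 2 standard deviations from the mean.
function sampleTruncatedNormal(mean = 0, stdDev = 1) {
  let z;
  do {
    z = sampleStandardNormal();
  } while (Math.abs(z) > 2);
  return mean + stdDev * z;
}

// Draw 1000 values with mean 4 and standard deviation 5.
const samples = Array.from({ length: 1000 }, () => sampleTruncatedNormal(4, 5));

// Every value must lie in [mean - 2 * stdDev, mean + 2 * stdDev] = [-6, 14].
const allInRange = samples.every((v) => v >= -6 && v <= 14);
console.log(allInRange); // prints: true
```

Because rejected candidates are drawn again rather than clamped, no values pile up at the boundaries -6 and 14.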



Syntax:

tf.truncatedNormal(shape, mean?, stdDev?, dtype?, seed?)

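Parameters:

shape: An array of integers that defines the shape of the output tensor.
mean: The mean of the normal distribution. It is optional and defaults to 0.
stdDev: The standard deviation of the normal distribution. It is optional and defaults to 1.
dtype: The data type of the output tensor, either 'float32' or 'int32'. It is optional.
seed: The seed for the random number generator. It is optional.
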
Return Value: It returns the tf.Tensor object.

Example 1:  




// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs"
  
// Calling truncatedNormal() method and
// Printing output
tf.truncatedNormal([3, 4]).print();

Output (the values are sampled randomly, so they differ on each run):

Tensor
    [[-0.0277713, -0.4777073, -0.3911407, 1.85613   ],
     [-0.0667888, -0.0867875, 0.8295102 , -0.5933844],
     [0.5160138 , 0.7871808 , 0.6818511 , 1.2177598 ]]

Example 2:




// Importing the tensorflow.js library 
import * as tf from "@tensorflow/tfjs"
  
// Defining shape, mean, standard deviation and data type
var sh = [3, 2];
var mean = 4;
var st_dev = 5;
var dtyp = 'int32';
  
// Calling truncatedNormal() method
var res = tf.truncatedNormal(sh, mean, st_dev, dtyp);
  
// Printing output
res.print();

Output (the values are sampled randomly, so they differ on each run):

Tensor
    [[-1, -5],
     [4 , 4 ],
     [11, 2 ]]
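As a quick sanity check on Example 2, the small plain-JavaScript sketch below computes the interval that the truncation allows, using the mean and standard deviation from that example, and confirms that the printed values fall inside it. The variable names are illustrative, and the array simply copies the output shown above.

```javascript
// Parameters and printed values copied from Example 2.
const mean = 4;
const stdDev = 5;
const printed = [-1, -5, 4, 4, 11, 2];

// Values more than 2 standard deviations from the mean are drawn again,
// so every result should lie between these two bounds.
const lower = mean - 2 * stdDev; // -6
const upper = mean + 2 * stdDev; // 14

const withinBounds = printed.every((v) => v >= lower && v <= upper);
console.log(lower, upper, withinBounds); // prints: -6 14 true
```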

Reference: https://js.tensorflow.org/api/latest/#truncatedNormal
