
Python – tensorflow.GradientTape.reset()

Last Updated: 10 Jul, 2020

TensorFlow is an open-source Python library developed by Google for building machine learning models and deep learning neural networks.

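reset() clears all information stored on the tape. Every operation recorded before the call is discarded, together with any tensors previously marked with watch(), so a later call to gradient() differentiates only the operations recorded after the reset.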
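The TensorFlow API documentation describes reset() as equivalent to exiting the tape's context manager and re-entering it with a new tape. The minimal sketch below, assuming TensorFlow 2.x with eager execution, checks that claim by computing the same gradient both ways. grad_with_reset and grad_with_new_tape are hypothetical helper names used only for illustration.

```python
import tensorflow as tf


def grad_with_reset(x):
    # Record x^3, discard it with reset(), then record only x^2.
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = x * x * x   # recorded, then forgotten by reset()
        tape.reset()
        tape.watch(x)   # reset() also clears the watch on x
        y += x * x      # only this term is differentiated
    return tape.gradient(y, x)


def grad_with_new_tape(x):
    # Same effect with two tapes: the first tape's record is never used.
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = x * x * x
    with tf.GradientTape() as tape:
        tape.watch(x)
        y += x * x
    return tape.gradient(y, x)


x = tf.constant(4.0)
print(grad_with_reset(x))     # tf.Tensor(8.0, shape=(), dtype=float32)
print(grad_with_new_tape(x))  # tf.Tensor(8.0, shape=(), dtype=float32)
```

In both versions only the x * x term contributes, giving 2x = 8 at x = 4. reset() is mainly a convenience for the case where the decision to discard the recording is made partway through a single with block.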


Syntax: reset()



Parameters: It doesn’t accept any parameters.

Returns: None.

Example 1:

Python3




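# Importing the library
import tensorflow as tf

x = tf.constant(4.0)

# Recording y = x^3 + x^2 on a single tape
with tf.GradientTape() as gfg:
    gfg.watch(x)  # constants are not watched automatically
    y = x * x * x
    y += x * x

# Computing gradient without reset: dy/dx = 3x^2 + 2x = 56 at x = 4
res = gfg.gradient(y, x)

# Printing result
print("res(y = x*x*x + x*x): ", res)

# Recording the same computation, but resetting midway
with tf.GradientTape() as gfg:
    gfg.watch(x)
    y = x * x * x

    # Resetting the tape discards x*x*x and the watch on x
    gfg.reset()

    # x must be watched again on the fresh tape
    gfg.watch(x)
    y += x * x

# Computing gradient with reset: only x*x was recorded, so dy/dx = 2x = 8
res = gfg.gradient(y, x)

# Printing result
print("res(y = x*x): ", res)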

Output:


res(y = x*x*x + x*x):  tf.Tensor(56.0, shape=(), dtype=float32)
res(y = x*x):  tf.Tensor(8.0, shape=(), dtype=float32)
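Because reset() starts a fresh recording, a constant watched before the reset is no longer watched afterwards. If the second watch() call in Example 1 were left out, nothing involving x would be recorded and gradient() would return None. The sketch below, assuming TensorFlow 2.x, makes this visible; grad_after_reset and its rewatch flag are hypothetical names introduced only for this illustration.

```python
import tensorflow as tf


def grad_after_reset(x, rewatch):
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = x * x * x
        tape.reset()       # discards the recording and the watch on x
        if rewatch:
            tape.watch(x)  # track x again on the fresh tape
        y += x * x
    # x is unconnected to y on the tape when it was not re-watched,
    # so gradient() returns None in that case.
    return tape.gradient(y, x)


x = tf.constant(4.0)
print(grad_after_reset(x, rewatch=True))   # tf.Tensor(8.0, shape=(), dtype=float32)
print(grad_after_reset(x, rewatch=False))  # None
```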

Example 2:

Python3




# Importing the library
import tensorflow as tf

x = tf.constant(3.0)

# Recording y = x^2 + x^2 on a single tape
with tf.GradientTape() as gfg:
    gfg.watch(x)  # constants are not watched automatically
    y = x * x
    y += x * x

# Computing gradient without reset: dy/dx = 4x = 12 at x = 3
res = gfg.gradient(y, x)

# Printing result
print("res(y = x*x + x*x): ", res)

# Recording again, but resetting midway
with tf.GradientTape() as gfg:
    gfg.watch(x)
    y = x * x

    # Resetting the tape discards x*x and the watch on x
    gfg.reset()
    gfg.watch(x)  # watch x again on the fresh tape
    y += x

# Computing gradient with reset: only "+ x" was recorded, so dy/dx = 1
res = gfg.gradient(y, x)

# Printing result
print("res(y = x): ", res)

Output:


res(y = x*x + x*x):  tf.Tensor(12.0, shape=(), dtype=float32)
res(y = x):  tf.Tensor(1.0, shape=(), dtype=float32)
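Both examples use tf.constant, which has to be watched explicitly. Trainable tf.Variable objects, by contrast, are watched automatically when read inside the tape (the default watch_accessed_variables=True), and the fresh recording started by reset() keeps that setting, so no second watch() call is needed. The sketch below reproduces Example 2 under that assumption, with a variable v standing in for x; the names v, full and reset_grad are illustrative only.

```python
import tensorflow as tf

v = tf.Variable(3.0)  # trainable variables are watched automatically

with tf.GradientTape() as tape:
    y = v * v
    y += v * v
# d(2v^2)/dv = 4v = 12 at v = 3
full = tape.gradient(y, v)

with tf.GradientTape() as tape:
    y = v * v
    tape.reset()  # forget v * v; no watch() call is needed afterwards
    y += v        # only this addition is recorded
reset_grad = tape.gradient(y, v)

print(full)        # tf.Tensor(12.0, shape=(), dtype=float32)
print(reset_grad)  # tf.Tensor(1.0, shape=(), dtype=float32)
```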
