tf.Module in Tensorflow Example

Last Updated : 22 Feb, 2024

TensorFlow is an open-source library for machine learning and numerical computation. It provides a wide range of tools and APIs. One of the core building blocks of TensorFlow is tf.Module, a class that represents a reusable piece of computation.

A tf.Module is an object that encapsulates a set of variables and the functions that operate on them. It can be used to define a layer, a model, or any other component of a neural network, and one tf.Module can be nested inside another, creating a hierarchical structure of modules.

How to create a tf.Module?

To create a tf.Module, we need to subclass the tf.Module class and define the variables and functions within the constructor or as class methods.

Example 1:

The following code defines a simple linear layer as a tf.Module. It uses the ‘tf.Module’ class as the base class for a custom module called ‘Linear’.

  • The ‘Linear’ class inherits from tf.Module, thereby gaining several features provided by the tf.Module class, such as variable tracking and serialization support.
  • Inside the constructor, __init__, of the Linear class, we initialize the variables w and b using tf.Variable. These variables represent the weights and biases of the linear layer, respectively. The tf.Variable function creates TensorFlow variables that are trainable and can be updated during training.
  • The __call__ method of the Linear class defines how instances of this class behave when they are called. In this implementation, it takes an input tensor x, performs matrix multiplication of x with the weights w, and adds the bias b to the result.
  • By inheriting from tf.Module, the variables w and b are automatically tracked by TensorFlow. This means that they are registered with the module and can be accessed via the variables attribute of the Linear instance. This tracking is useful for tasks like saving and loading models, as well as accessing variables during training.
  • Because Linear is a subclass of tf.Module, instances of this class can be easily serialized using TensorFlow’s serialization mechanisms. This allows you to save the entire model, including its variables and structure, to disk for later use or deployment.

Python




import tensorflow as tf

class Linear(tf.Module):
  def __init__(self, units, input_dim):
    super().__init__()
    # Weight matrix of shape (input_dim, units), randomly initialized
    self.w = tf.Variable(tf.random.normal(shape=(input_dim, units)), name="weights")
    # Bias vector of shape (units,), initialized to zeros
    self.b = tf.Variable(tf.zeros(shape=(units,)), name="bias")

  def __call__(self, x):
    # Linear transformation: y = xW + b
    return tf.matmul(x, self.w) + self.b


The above code defines a custom linear layer using the tf.Module class in TensorFlow. A linear layer is a basic building block of neural networks that performs a linear transformation on its input, following the formula y = xW + b.

In the usage example below, a sample input tensor is created with tf.constant, representing a single data point with two features. The linear layer is then applied to this tensor, producing an output tensor that contains the transformed data, and the result is printed.

Here is an example of how to use the Linear class to create a linear layer with 3 output units and 2 input dimensions, and apply it to an input tensor of shape (1, 2):

Python




# Create a linear layer with 3 output units and 2 input dimensions
linear_layer = Linear(units=3, input_dim=2)
 
# Create an input tensor of shape (1, 2)
x = tf.constant([[1, 2]], dtype=tf.float32)
 
# Apply the linear layer to the input tensor
y = linear_layer(x)
 
# Print the output tensor
print(y)


Output:

tf.Tensor([[ 1.2542576  -0.08342385 -2.3886342 ]], shape=(1, 3), dtype=float32)
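
Note that the exact values will differ between runs, because the weights are initialized randomly. As noted in the points above, the variables created in the constructor are tracked automatically by tf.Module. A quick way to confirm this, continuing with the linear_layer instance from the example, is to inspect the variables and trainable_variables attributes:

Python

# Variables defined in the constructor are tracked by the module
print(linear_layer.variables)            # tuple containing w and b
print(linear_layer.trainable_variables)  # both are trainable by default

# Each tracked variable keeps the name given in tf.Variable(..., name=...)
for var in linear_layer.variables:
    print(var.name, var.shape)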

Example 2:

We can create our own tf.Module by subclassing it and defining our variables and functions. We can also use it as a container for other tf.Module objects.

Steps:

To create our own tf.Module, we need to follow these steps:

  • Import the TensorFlow library and the tf.Module class.
  • Define a subclass of tf.Module and give it a name.
  • In the __init__ method, initialize any variables or submodules that your module will use.
  • Define one or more methods that perform the computation or logic of your module. You can use the variables and submodules defined in the __init__ method, as well as any inputs or arguments passed to the methods.
  • Optionally, decorate methods with @tf.Module.with_name_scope so that the variables and operations they create are grouped under the module’s name scope, as shown in the sketch after this list.
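
As an optional illustration of the last step, the short sketch below decorates __call__ with @tf.Module.with_name_scope; the ScaledShift class name is invented for this example and is not part of TensorFlow:

Python

class ScaledShift(tf.Module):
    @tf.Module.with_name_scope
    def __call__(self, x):
        # Variables created inside the decorated method pick up the
        # module's name scope (e.g. "scaled_shift/..." in their names)
        if not hasattr(self, "scale"):
            self.scale = tf.Variable(2.0, name="scale")
        return self.scale * x

mod = ScaledShift(name="scaled_shift")
print(mod(tf.constant(3.0)))   # 2.0 * 3.0 = 6.0
print(mod.scale.name)          # the variable name includes the module's scope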

Here is an example of a simple tf.Module that operates on a scalar tensor:

The SimpleModule class inherits from tf.Module. Its constructor initializes two variables: one trainable and one non-trainable. The @tf.function decorator converts the __call__ method into a TensorFlow graph function, which can improve performance by optimizing the computation graph.

Python




class SimpleModule(tf.Module):
    def __init__(self, name=None):
        super().__init__(name=name)
        # Trainable variable (appears in trainable_variables)
        self.a_variable = tf.Variable(5.0, name="train_me")
        # Non-trainable variable (still tracked, but excluded from training)
        self.non_trainable_variable = tf.Variable(
            5.0, trainable=False, name="do_not_train_me")

    @tf.function
    def __call__(self, x):
        # Compute a_variable * x + non_trainable_variable
        return self.a_variable * x + self.non_trainable_variable


simple_module = SimpleModule(name="simple")
print(simple_module(tf.constant(5.0)))


Output:

tf.Tensor(30.0, shape=(), dtype=float32)

The output of the code is 30.0. This is because the __call__ method of the SimpleModule class returns the product of self.a_variable and x plus self.non_trainable_variable. Since self.a_variable is initialized to 5.0, self.non_trainable_variable is initialized to 5.0, and x is passed as 5.0, the result is 5.0 * 5.0 + 5.0 = 30.0.
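
As mentioned earlier, a tf.Module can also act as a container for other modules, and submodules assigned as attributes are tracked along with their variables. Here is a minimal sketch reusing the Linear class from Example 1; the MLP class name is made up for illustration:

Python

class MLP(tf.Module):
    def __init__(self, name=None):
        super().__init__(name=name)
        # Submodules assigned as attributes are tracked automatically
        self.hidden = Linear(units=3, input_dim=2)
        self.out = Linear(units=1, input_dim=3)

    def __call__(self, x):
        # Chain the two nested linear layers
        return self.out(self.hidden(x))


mlp = MLP(name="mlp")
print(mlp(tf.constant([[1.0, 2.0]])))               # output of the nested modules
print([type(m).__name__ for m in mlp.submodules])   # ['Linear', 'Linear']
print(len(mlp.variables))                           # 4: w and b of each submodule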

Advantages of tf.Module

The advantages of using tf.Module are:

  1. The module simplifies the management of variables and functions. A tf.Module automatically tracks the variables and functions that are defined within its scope and exposes them as attributes. This makes it easy to access and manipulate the internal state of a module.
  2. It supports checkpointing and saving. A tf.Module can be saved and restored using tf.train.Checkpoint, which preserves the values of the variables and the structure of the module. This is useful for saving and resuming the training process, or for exporting the model for inference.
  3. It supports serialization and deserialization. A tf.Module can be converted to and from a SavedModel, which is a standard format for storing and exchanging TensorFlow models. A SavedModel contains not only the variables and functions, but also the signatures and metadata of the module. This enables the model to be loaded and executed by different TensorFlow tools and runtimes, such as tf.keras, TensorFlow Lite, TensorFlow.js, or TensorFlow Serving. A combined sketch of points 2 and 3 follows this list.
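
Here is a combined sketch of points 2 and 3, reusing the simple_module instance from Example 2. The file paths are illustrative, and exporting the module as a SavedModel relies on the fact that its @tf.function __call__ has already been traced by the earlier call:

Python

# Checkpointing: save and restore the module's variables
checkpoint = tf.train.Checkpoint(model=simple_module)
save_path = checkpoint.save("/tmp/simple_module_ckpt")   # illustrative path
checkpoint.restore(save_path)

# SavedModel: export the module, including its traced tf.function
tf.saved_model.save(simple_module, "/tmp/simple_module_saved")   # illustrative path
restored = tf.saved_model.load("/tmp/simple_module_saved")
print(restored(tf.constant(5.0)))   # same result as the original module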

Conclusion

tf.Module is a class in TensorFlow that helps us manage variables and functions. It is useful for saving and restoring model state and for creating reusable components. As a powerful and flexible base class, it enables the creation and reuse of computational components in TensorFlow, simplifies variable and function management, and supports checkpointing, saving, and serialization of modules. By using tf.Module, we can build complex models with modular and readable code.


