
tf.Module in TensorFlow Example

TensorFlow is an open-source library for machine learning and numerical computation. One of its core building blocks is tf.Module, a base class that represents a reusable piece of computation.

A tf.Module is an object that encapsulates a set of variables and functions that operate on them. A tf.Module can be used to define a layer, a model, or any other component of a neural network. A tf.Module can also be nested inside another tf.Module, creating a hierarchical structure of modules.

How to create a tf.Module?

To create a tf.Module, we subclass the tf.Module class, create the variables in the constructor, and define the computation as methods (typically __call__).

Example 1:

The following code uses the tf.Module class as the base class for a custom module called Linear, which implements a simple linear layer:

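import tensorflow as tf

class Linear(tf.Module):
  def __init__(self, units, input_dim, name=None):
    super().__init__(name=name)  # initialize the tf.Module base class (sets the module name)
    # Weights of shape (input_dim, units), drawn from a normal distribution
    self.w = tf.Variable(tf.random.normal(shape=(input_dim, units)), name="weights")
    # Bias of shape (units,), initialized to zeros
    self.b = tf.Variable(tf.zeros(shape=(units,)), name="bias")

  def __call__(self, x):
    # Linear transformation: y = xw + b
    return tf.matmul(x, self.w) + self.b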

The above code defines a custom linear layer using the tf.Module class in TensorFlow. A linear layer is a basic building block of neural networks that performs a linear transformation on its input, following the formula y = xw + b. The constructor creates two variables: a weight matrix w of shape (input_dim, units), initialized from a normal distribution, and a bias vector b of shape (units,), initialized to zeros. The __call__ method multiplies the input by w and adds b.

Here is an example of how to use the Linear class to create a linear layer with 3 output units and 2 input dimensions, and apply it to an input tensor of shape (1, 2). A sample input tensor is created using tf.constant, representing a single data point with two features. The linear layer is then applied to the input tensor, and the resulting output tensor is printed:

# Create a linear layer with 3 output units and 2 input dimensions
linear_layer = Linear(units=3, input_dim=2)
# Create an input tensor of shape (1, 2)
x = tf.constant([[1, 2]], dtype=tf.float32)
# Apply the linear layer to the input tensor
y = linear_layer(x)
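# Print the output tensor
print(y)

Output (the exact values vary from run to run because the weights are initialized randomly):

tf.Tensor([[ 1.2542576  -0.08342385 -2.3886342 ]], shape=(1, 3), dtype=float32)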
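Because a tf.Module can be nested inside another tf.Module, layers like Linear can be composed into a larger module. The following is a minimal sketch of that idea, not a definitive pattern: the TwoLayerNet class, its attribute names, and the layer sizes are hypothetical and chosen only for illustration. Linear is repeated (with an optional name argument) so that the block runs on its own.

```python
import tensorflow as tf

class Linear(tf.Module):
    """The linear layer from Example 1, repeated so this block is self-contained."""
    def __init__(self, units, input_dim, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(tf.random.normal(shape=(input_dim, units)), name="weights")
        self.b = tf.Variable(tf.zeros(shape=(units,)), name="bias")

    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b

class TwoLayerNet(tf.Module):
    """Hypothetical container module; the name and layer sizes are illustrative."""
    def __init__(self, name=None):
        super().__init__(name=name)
        # Modules assigned as attributes are found when the outer module
        # walks its attributes to collect submodules and variables.
        self.hidden = Linear(units=4, input_dim=2)
        self.out = Linear(units=3, input_dim=4)

    def __call__(self, x):
        return self.out(tf.nn.relu(self.hidden(x)))

net = TwoLayerNet(name="two_layer_net")
y = net(tf.constant([[1.0, 2.0]]))

print(y.shape)               # (1, 3)
print(len(net.submodules))   # 2: the two Linear modules
print(len(net.variables))    # 4: weights and bias of each Linear
```

The outer module does not register its children explicitly. Assigning them as attributes is enough for net.submodules and net.variables to find them.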

Example 2:

We can create our own tf.Module by subclassing it and defining our variables and functions. We can also use it as a container for other tf.Module objects.


Here is an example of a simple tf.Module that operates on a scalar tensor:

The SimpleModule class inherits from tf.Module. The constructor initializes two variables, one trainable and one non-trainable. The @tf.function decorator on __call__ converts the method into a TensorFlow graph function, which can improve performance by optimizing the computation graph.

class SimpleModule(tf.Module):
    def __init__(self, name=None):
        super().__init__(name=name)
        # Trainable by default
        self.a_variable = tf.Variable(5.0, name="train_me")
        # Excluded from trainable_variables
        self.non_trainable_variable = tf.Variable(
            5.0, trainable=False, name="do_not_train_me")

    @tf.function
    def __call__(self, x):
        return self.a_variable * x + self.non_trainable_variable

simple_module = SimpleModule(name="simple")
print(simple_module(tf.constant(5.0)))



Calling simple_module(tf.constant(5.0)) evaluates to 30.0. This is because the __call__ method of the SimpleModule class returns the product of self.a_variable and x plus self.non_trainable_variable. Since both variables are initialized to 5.0 and x is passed as 5.0, the result is 5.0 * 5.0 + 5.0 = 30.0.
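The difference between the trainable and the non-trainable variable becomes visible when we inspect the module or take gradients. The sketch below repeats SimpleModule so that it runs on its own. The use of tf.GradientTape is an addition for illustration and is not part of the example above.

```python
import tensorflow as tf

class SimpleModule(tf.Module):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.a_variable = tf.Variable(5.0, name="train_me")
        self.non_trainable_variable = tf.Variable(
            5.0, trainable=False, name="do_not_train_me")

    @tf.function
    def __call__(self, x):
        return self.a_variable * x + self.non_trainable_variable

simple_module = SimpleModule(name="simple")

# All tracked variables versus only the trainable ones.
print([v.name for v in simple_module.variables])
print([v.name for v in simple_module.trainable_variables])

# Differentiate the output with respect to the trainable variables only.
with tf.GradientTape() as tape:
    y = simple_module(tf.constant(5.0))
grads = tape.gradient(y, simple_module.trainable_variables)

print(y.numpy())         # 30.0 = 5.0 * 5.0 + 5.0
print(grads[0].numpy())  # 5.0, because dy/da = x
```

An optimizer that is given only trainable_variables would update a_variable and leave non_trainable_variable unchanged.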

Advantages of tf.Module

The advantages of using tf.Module are:

  1. It simplifies the management of variables and submodules. A tf.Module automatically tracks the tf.Variable and tf.Module objects assigned to its attributes and exposes them through properties such as variables, trainable_variables, and submodules. This makes it easy to access and manipulate the internal state of a module.
  2. It supports checkpointing. The variables of a tf.Module can be saved and restored using tf.train.Checkpoint, which stores the variable values together with their place in the module's attribute structure. This is useful for saving and resuming the training process. Restoring a checkpoint still requires the Python code that defines the module.
  3. It supports serialization and deserialization. A tf.Module can be exported with tf.saved_model.save and loaded back with tf.saved_model.load. SavedModel is the standard format for storing and exchanging TensorFlow models. It contains not only the variables, but also the graphs and signatures of the methods decorated with tf.function. This enables the model to be loaded without the original Python code and used by other parts of the TensorFlow ecosystem, such as TensorFlow Lite and TensorFlow.js (after conversion) or TensorFlow Serving.
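The checkpointing and SavedModel points above can be sketched as follows. This is a minimal illustration under stated assumptions: the Scale module, its factor variable, and the temporary directory layout are hypothetical, and the input_signature is provided so that the method can be exported without being called first.

```python
import os
import tempfile
import tensorflow as tf

class Scale(tf.Module):
    """Hypothetical one-variable module used only for this illustration."""
    def __init__(self, name=None):
        super().__init__(name=name)
        self.factor = tf.Variable(2.0, name="factor")

    # Methods decorated with tf.function are what a SavedModel can export.
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
    def __call__(self, x):
        return self.factor * x

module = Scale(name="scale")
workdir = tempfile.mkdtemp()

# 1. Checkpoint: stores variable values; restoring needs the Python class.
module.factor.assign(3.0)
ckpt_path = tf.train.Checkpoint(model=module).write(os.path.join(workdir, "ckpt"))

fresh = Scale(name="scale")                   # factor starts at 2.0 again
tf.train.Checkpoint(model=fresh).restore(ckpt_path)
print(fresh.factor.numpy())                   # 3.0

# 2. SavedModel: stores the variables plus the traced function graphs.
export_dir = os.path.join(workdir, "saved_model")
tf.saved_model.save(module, export_dir)
loaded = tf.saved_model.load(export_dir)      # does not need the Scale class
print(loaded(tf.constant(4.0)).numpy())       # 12.0
```

A checkpoint suits resuming training, where the Python class is already available. A SavedModel suits deployment, where the consumer may not have the original source.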


tf.Module is a flexible base class for creating reusable computational components in TensorFlow. It simplifies variable and function management, and supports checkpointing and export to SavedModel, which makes it useful for saving and restoring model state. By using tf.Module, we can build complex models with modular and readable code.
