How to Reshape a Tensor in TensorFlow?

Reshaping a tensor changes its shape while keeping the order and total number of its elements the same. It is a fundamental operation in TensorFlow that lets you view the same data under different dimensions without changing the underlying values.

Using tf.reshape()

In TensorFlow, the tf.reshape() function is used to reshape tensors.
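As a minimal sketch of what "only the shape changes" means, the snippet below reshapes a small tensor and checks that the element count and the row-major element order are preserved. The variable names and values are illustrative assumptions, not part of the original article.

```python
import tensorflow as tf

# Illustrative 2x3 tensor (values chosen arbitrarily for this sketch).
original = tf.constant([[1, 2, 3],
                        [4, 5, 6]])

# The same six elements, viewed as 3 rows of 2.
reshaped = tf.reshape(original, [3, 2])

# The element count is unchanged.
print(tf.size(original).numpy(), tf.size(reshaped).numpy())

# Flattening both tensors gives the same sequence, so the order is unchanged too.
same_order = tf.reduce_all(
    tf.equal(tf.reshape(original, [6]), tf.reshape(reshaped, [6]))
)
print(same_order.numpy())
```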



Syntax:

tf.reshape(tensor, shape, name=None)

Parameters:

tensor: The tensor whose shape you want to change.

shape: The shape of the output tensor. It must describe the same total number of elements as the input tensor.

name: An optional name for the operation.
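The parameters above can be exercised with a short sketch. It assumes eager execution (the TensorFlow 2.x default); the operation name "to_2x3" and the variable names are hypothetical. The second call requests a shape with a different element count, which TensorFlow rejects; the exact exception class may depend on the execution mode, so the sketch catches both likely candidates.

```python
import tensorflow as tf

values = tf.constant([1, 2, 3, 4, 5, 6])  # 6 elements

# 2 * 3 == 6, so this shape is compatible; `name` is optional.
ok = tf.reshape(values, [2, 3], name="to_2x3")
print(ok.shape)

# 4 * 2 == 8, which does not match 6 elements, so this request fails.
try:
    tf.reshape(values, [4, 2])
    mismatch_raised = False
except (tf.errors.InvalidArgumentError, ValueError) as err:
    mismatch_raised = True
    print("Reshape failed:", type(err).__name__)
```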

Data Types in TensorFlow

A list of some commonly used data types in TensorFlow:

  1. tf.float16: 16-bit floating-point.
  2. tf.float32: 32-bit floating-point.
  3. tf.float64: 64-bit floating-point.
  4. tf.int8: 8-bit integer.
  5. tf.int16: 16-bit integer.
  6. tf.int32: 32-bit integer.
  7. tf.int64: 64-bit integer.
  8. tf.uint8: 8-bit unsigned integer.
  9. tf.bool: Boolean.
  10. tf.string: String.

These data types can be specified when creating tensors or operations using TensorFlow functions like tf.constant(), tf.Variable(), etc. When the dtype is inferred from plain Python values, TensorFlow uses tf.float32 for floating-point numbers and tf.int32 for integers.
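The following sketch illustrates how dtypes are inferred or set explicitly, and that reshaping leaves the dtype untouched. The variable names are illustrative assumptions; tf.cast is shown only as the usual separate step for changing a dtype.

```python
import tensorflow as tf

ints = tf.constant([[1, 2], [3, 4]])             # inferred as tf.int32
floats = tf.constant([[1.0, 2.0], [3.0, 4.0]])   # inferred as tf.float32
explicit = tf.constant([[1, 2], [3, 4]], dtype=tf.float64)

# tf.reshape changes only the shape; the dtype carries over.
flat = tf.reshape(explicit, [4])
print(flat.dtype)

# Changing the dtype is a separate step, done with tf.cast.
as_int64 = tf.cast(ints, tf.int64)
print(as_int64.dtype)
```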

Importing TensorFlow




import tensorflow as tf

Initial Tensor Shape




t1 = tf.constant([[9, 7, 8],
                  [11, 4, 0]])
print('Tensor :\n', t1)
print('\nShape of Tensor:', tf.shape(t1).numpy())

Output:

Tensor :
tf.Tensor(
[[ 9 7 8]
[11 4 0]], shape=(2, 3), dtype=int32)
Shape of Tensor: [2 3]

Reshaping to a 1D Tensor




t2 = tf.reshape(t1, [6])
 
print('Tensor :\n', t2)
print('\nShape of Tensor:', tf.shape(t2).numpy())

Output:

Tensor :
tf.Tensor([ 9 7 8 11 4 0], shape=(6,), dtype=int32)
Shape of Tensor: [6]

Reshaping to a 2D Tensor

The code below reshapes the 1D tensor t2 into a 2D tensor with 1 row and 6 columns.




t3 = tf.reshape(t2, [1, 6])
 
print('Tensor :\n', t3)
print('\nShape of Tensor:', tf.shape(t3).numpy())

Output:

Tensor :
tf.Tensor([[ 9 7 8 11 4 0]], shape=(1, 6), dtype=int32)
Shape of Tensor: [1 6]

Using Transposition Operations to Reshape

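tf.transpose() rearranges the axes of a tensor according to its perm argument, which lists the input axes in the order they should appear in the output. Unlike tf.reshape(), which keeps the elements in their original row-major order, transposition moves elements to new positions, so the two operations can produce the same shape with different contents. The example below applies both to the same 2x3 tensor.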




import tensorflow as tf
# Original tensor
t = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
print('Original Tensor:')
print(t)
 
# Reshaping into a 3x2 tensor
reshaped_tensor = tf.reshape(t, [3, 2])
print("\nReshaped tensor:")
print(reshaped_tensor)
 
# Transposing with permuted dimensions
transposed_tensor = tf.transpose(t, perm=[1, 0])
print("\nTransposed tensor:")
print(transposed_tensor)

Output:

Original Tensor:
tf.Tensor(
[[1 2 3]
[4 5 6]], shape=(2, 3), dtype=int32)
Reshaped tensor:
tf.Tensor(
[[1 2]
[3 4]
[5 6]], shape=(3, 2), dtype=int32)
Transposed tensor:
tf.Tensor(
[[1 4]
[2 5]
[3 6]], shape=(3, 2), dtype=int32)
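The same idea extends to higher-rank tensors. The sketch below is an illustrative assumption (a 2x3x4 tensor built with tf.range, not taken from the article) showing how perm=[0, 2, 1] swaps the last two axes, and how that differs from reshaping to the same target shape.

```python
import tensorflow as tf

# 24 consecutive integers arranged as shape (2, 3, 4).
t = tf.reshape(tf.range(24), [2, 3, 4])

# perm=[0, 2, 1] keeps axis 0 and swaps the last two axes.
swapped = tf.transpose(t, perm=[0, 2, 1])
print(swapped.shape)  # (2, 4, 3)

# Element [i, j, k] of the result is element [i, k, j] of the input.
print(int(swapped[0, 1, 2]), int(t[0, 2, 1]))

# Reshaping to the same shape keeps row-major order instead,
# so the contents differ from the transposed tensor.
reshaped = tf.reshape(t, [2, 4, 3])
differs = not bool(tf.reduce_all(tf.equal(swapped, reshaped)))
print(differs)
```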

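Using the special value -1 as shape

Passing -1 for a dimension tells tf.reshape() to infer that dimension's size from the total number of elements; at most one dimension may be -1. A shape of [-1] flattens the tensor to 1D.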




t = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
 
# Reshape into a 1D tensor (flattened)
result = tf.reshape(t, [-1])
 
# Print the result
print('Flattened Tensor:\n',result)
 
print('\n Original Tensor Shape:',tf.shape(t)) 
print('Flattened Tensor Shape:',tf.shape(result))

Output:

Flattened Tensor:
tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32)
Original Tensor Shape: tf.Tensor([2 3], shape=(2,), dtype=int32)
Flattened Tensor Shape: tf.Tensor([6], shape=(1,), dtype=int32)
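Beyond flattening, -1 can stand in for a single dimension of a multi-dimensional shape. The sketch below reuses the same 2x3 tensor; the target shapes and variable names are illustrative assumptions.

```python
import tensorflow as tf

t = tf.constant([[1, 2, 3],
                 [4, 5, 6]])

# 6 elements / 3 rows: TensorFlow infers 2 columns.
three_rows = tf.reshape(t, [3, -1])
print(three_rows.shape)  # (3, 2)

# -1 can sit in any position: 6 / (1 * 2) gives a middle dimension of 3.
middle = tf.reshape(t, [1, -1, 2])
print(middle.shape)  # (1, 3, 2)

# The static shape is also available as an attribute, without tf.shape().
print(t.shape.as_list())  # [2, 3]
```

Printing t.shape gives the static shape directly, whereas tf.shape(t) returns a tensor, which is why the output above shows tf.Tensor(...) wrappers when .numpy() is not called.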

