
U-Net Architecture Explained

Last Updated : 08 Jun, 2023

U-Net is a widely used deep learning architecture first introduced by Ronneberger, Fischer, and Brox in the 2015 paper “U-Net: Convolutional Networks for Biomedical Image Segmentation”. It was designed to address the scarcity of annotated training data in medical imaging: the network can be trained effectively on a relatively small number of labeled images while remaining fast and accurate.

U-Net Architecture:

The U-Net architecture consists of two symmetric halves, a contracting path and an expansive path, which together give the network its characteristic U shape. The contracting path contains encoder layers that capture contextual information while reducing the spatial resolution of the input. The expansive path contains decoder layers that upsample the encoded features and, through skip connections, combine them with the matching feature maps from the contracting path to produce a segmentation map.

The contracting path is responsible for identifying the relevant features in the input image. Its convolutional layers, interleaved with max pooling, reduce the spatial resolution of the feature maps while increasing their depth (number of channels), thereby capturing increasingly abstract representations of the input. This part follows the typical architecture of a standard convolutional network. The expansive path, on the other hand, decodes these features and localizes them by progressively restoring spatial resolution. Its decoder layers upsample the feature maps and apply convolutions. The skip connections carry high-resolution features from the contracting path to the decoder, which preserves spatial detail that would otherwise be lost during downsampling and helps the decoder locate features more precisely.

Figure 1: U-Net architecture

Figure 1 illustrates how U-Net converts a grayscale input image of size 572×572×1 into a two-channel (foreground/background) segmentation map of size 388×388×2. The output is smaller than the input because the convolutions use no padding ("valid" convolutions): each 3×3 convolution trims one pixel from every border. Using padding ("same" convolutions) would preserve the input size instead. Along the contracting path, the feature maps shrink in height and width while the number of channels grows, which lets the network capture higher-level features as it goes deeper. At the bottleneck, two 3×3 convolutions turn the 32×32×512 feature map into a 30×30×1024 map and then into a 28×28×1024 map. The expansive path takes this bottleneck output and upsamples it back toward the input resolution using up-convolutions, each of which doubles the height and width while halving the number of channels. The skip connections from the contracting path help the decoder layers locate and refine features in the image. Finally, each pixel in the output map is assigned a label corresponding to an object or class in the input image. Here the output is a binary segmentation map in which each pixel is labeled as either foreground or background.
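The sizes in Figure 1 follow from simple arithmetic: each unpadded 3×3 convolution removes 2 pixels from the height and width, each 2×2 max pooling halves them, and each 2×2 up-convolution doubles them. The short Python sketch below traces this for a square input, assuming a depth of four levels as in the figure; `trace_unet_sizes` is an illustrative helper, not part of any library.

```python
def trace_unet_sizes(size: int, depth: int = 4) -> list[tuple[str, int]]:
    """Return (stage, side length) pairs for a U-Net built from unpadded 3x3 convolutions."""
    trace = [("input", size)]
    # Contracting path: two valid 3x3 convolutions (-2 each), then 2x2 max pooling (/2)
    for level in range(1, depth + 1):
        size -= 4
        trace.append((f"encoder {level}", size))
        if size % 2:
            raise ValueError(f"odd size {size} before pooling at level {level}")
        size //= 2
    # Bottleneck: two more valid 3x3 convolutions
    size -= 4
    trace.append(("bottleneck", size))
    # Expansive path: 2x2 up-convolution (x2), then two valid 3x3 convolutions (-2 each)
    for level in range(1, depth + 1):
        size = size * 2 - 4
        trace.append((f"decoder {level}", size))
    return trace


for stage, side in trace_unet_sizes(572):
    print(f"{stage:>12}: {side}x{side}")
```

The last line printed is 388x388, matching the output size in Figure 1. The odd-size check reflects a constraint of the unpadded design: the input size must be chosen so that every feature map entering a 2×2 pooling layer has an even side length.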

Build the Model:

Next, we implement the U-Net architecture in Python 3 with the TensorFlow library. The implementation has three parts. First, we define the encoder block used in the contracting path: two 3×3 convolution layers, each followed by a ReLU activation, and then a 2×2 max pooling layer. Second, we define the decoder block, which upsamples the feature map from the level below with a 2×2 transposed convolution, concatenates it with the encoder features from the same level, and then applies two 3×3 convolutions, each followed by a ReLU activation. Third, we assemble the U-Net model from these blocks. Note that, unlike the original paper, which center-crops the encoder feature map taken just before pooling, this implementation resizes the pooled encoder output to match the upsampled feature map.

Encoder

Here’s the code for the encoder block:

Python3




import tensorflow as tf

def encoder_block(inputs, num_filters):

    # Convolution with 3x3 filter followed by ReLU activation
    x = tf.keras.layers.Conv2D(num_filters,
                               3,
                               padding = 'valid')(inputs)
    x = tf.keras.layers.Activation('relu')(x)

    # Convolution with 3x3 filter followed by ReLU activation
    x = tf.keras.layers.Conv2D(num_filters,
                               3,
                               padding = 'valid')(x)
    x = tf.keras.layers.Activation('relu')(x)

    # Max pooling with a 2x2 window and stride 2 halves the height and width
    x = tf.keras.layers.MaxPool2D(pool_size = (2, 2),
                                  strides = 2)(x)

    # Only the pooled output is returned; it also serves as the skip connection
    return x
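
To see what one encoder block does to the spatial size without TensorFlow, here is a simplified NumPy sketch. It assumes a single input channel and a single filter per convolution, whereas a real block applies `num_filters` filters across all input channels. Like Keras `Conv2D`, the hypothetical helper `conv3x3_valid` computes a cross-correlation with no padding; `relu` and `max_pool_2x2` are likewise illustrative names.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv3x3_valid(image: np.ndarray, kernel: np.ndarray, bias: float = 0.0) -> np.ndarray:
    """Unpadded 3x3 cross-correlation of a 2-D array (output loses 2 rows and 2 columns)."""
    windows = sliding_window_view(image, (3, 3))  # shape (H - 2, W - 2, 3, 3)
    return np.einsum("ijkl,kl->ij", windows, kernel) + bias


def relu(x: np.ndarray) -> np.ndarray:
    """Element-wise ReLU activation."""
    return np.maximum(x, 0.0)


def max_pool_2x2(x: np.ndarray) -> np.ndarray:
    """2x2 max pooling with stride 2 on a 2-D array with even height and width."""
    h, w = x.shape
    return x.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))


rng = np.random.default_rng(0)
image = rng.random((572, 572))
k1, k2 = rng.standard_normal((2, 3, 3))

x = relu(conv3x3_valid(image, k1))  # 570 x 570
x = relu(conv3x3_valid(x, k2))      # 568 x 568
pooled = max_pool_2x2(x)            # 284 x 284
print(x.shape, pooled.shape)        # (568, 568) (284, 284)
```

The shapes match the first encoder level in the model summary (570, 568, then 284 after pooling).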


Decoder

Next, we define the decoder block.

Python3




import tensorflow as tf

def decoder_block(inputs, skip_features, num_filters):

    # Upsampling with a 2x2 transposed convolution (stride 2 doubles height and width)
    x = tf.keras.layers.Conv2DTranspose(num_filters,
                                        (2, 2),
                                        strides = 2,
                                        padding = 'valid')(inputs)

    # Resize the skip features to match the spatial shape of the upsampled input
    # (the original paper center-crops the encoder features instead of resizing them)
    skip_features = tf.image.resize(skip_features,
                                    size = (x.shape[1],
                                            x.shape[2]))
    x = tf.keras.layers.Concatenate()([x, skip_features])

    # Convolution with 3x3 filter followed by ReLU activation
    x = tf.keras.layers.Conv2D(num_filters,
                               3,
                               padding = 'valid')(x)
    x = tf.keras.layers.Activation('relu')(x)

    # Convolution with 3x3 filter followed by ReLU activation
    x = tf.keras.layers.Conv2D(num_filters, 3, padding = 'valid')(x)
    x = tf.keras.layers.Activation('relu')(x)

    return x
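
For comparison, the original paper crops rather than resizes: it takes the encoder feature map from just before pooling and cuts out its central region to match the upsampled map, so no interpolation is involved. Below is a minimal NumPy sketch of that center crop, assuming channels-last (height, width, channels) arrays and the 64×64 to 56×56 case from Figure 1; `center_crop` is a hypothetical helper. In Keras, the `tf.keras.layers.Cropping2D` layer performs the same kind of crop.

```python
import numpy as np


def center_crop(features: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """Center-crop a (height, width, channels) array to (target_h, target_w, channels)."""
    h, w = features.shape[:2]
    if target_h > h or target_w > w:
        raise ValueError("target size must not exceed the input size")
    top = (h - target_h) // 2
    left = (w - target_w) // 2
    return features[top:top + target_h, left:left + target_w, :]


rng = np.random.default_rng(0)
skip = rng.random((64, 64, 512), dtype=np.float32)       # encoder features before pooling
upsampled = rng.random((56, 56, 512), dtype=np.float32)  # output of the 2x2 up-convolution

cropped = center_crop(skip, *upsampled.shape[:2])
merged = np.concatenate([upsampled, cropped], axis=-1)   # skip connection by concatenation
print(cropped.shape, merged.shape)  # (56, 56, 512) (56, 56, 1024)
```

Cropping keeps the skip features pixel-exact, whereas resizing a pooled map (as the code above does) interpolates features that have already lost resolution.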


U-Net

Using these blocks, we define the U-Net model and print its summary.

Python3




# Unet code
import tensorflow as tf
  
def unet_model(input_shape = (256, 256, 3), num_classes = 1):
    inputs = tf.keras.layers.Input(input_shape)
      
    # Contracting Path
    s1 = encoder_block(inputs, 64)
    s2 = encoder_block(s1, 128)
    s3 = encoder_block(s2, 256)
    s4 = encoder_block(s3, 512)
      
    # Bottleneck
    b1 = tf.keras.layers.Conv2D(1024, 3, padding = 'valid')(s4)
    b1 = tf.keras.layers.Activation('relu')(b1)
    b1 = tf.keras.layers.Conv2D(1024, 3, padding = 'valid')(b1)
    b1 = tf.keras.layers.Activation('relu')(b1)
      
    # Expansive Path
    s5 = decoder_block(b1, s4, 512)
    s6 = decoder_block(s5, s3, 256)
    s7 = decoder_block(s6, s2, 128)
    s8 = decoder_block(s7, s1, 64)
      
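    # Output: a 1x1 convolution maps each pixel's features to per-class scores
    # (sigmoid for a single-channel mask, softmax across channels for two or more classes)
    outputs = tf.keras.layers.Conv2D(num_classes,
                                     1,
                                     padding = 'valid',
                                     activation = 'sigmoid' if num_classes == 1 else 'softmax')(s8)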
      
    model = tf.keras.models.Model(inputs = inputs, 
                                  outputs = outputs, 
                                  name = 'U-Net')
    return model
  
if __name__ == '__main__':
    model = unet_model(input_shape=(572, 572, 3), num_classes=2)
    model.summary()


Output:

Model: "U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_6 (InputLayer)           [(None, 572, 572, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_95 (Conv2D)             (None, 570, 570, 64  1792        ['input_6[0][0]']                
                                )                                                                 
                                                                                                  
 activation_90 (Activation)     (None, 570, 570, 64  0           ['conv2d_95[0][0]']              
                                )                                                                 
                                                                                                  
 conv2d_96 (Conv2D)             (None, 568, 568, 64  36928       ['activation_90[0][0]']          
                                )                                                                 
                                                                                                  
 activation_91 (Activation)     (None, 568, 568, 64  0           ['conv2d_96[0][0]']              
                                )                                                                 
                                                                                                  
 max_pooling2d_20 (MaxPooling2D  (None, 284, 284, 64  0          ['activation_91[0][0]']          
 )                              )                                                                 
                                                                                                  
 conv2d_97 (Conv2D)             (None, 282, 282, 12  73856       ['max_pooling2d_20[0][0]']       
                                8)                                                                
                                                                                                  
 activation_92 (Activation)     (None, 282, 282, 12  0           ['conv2d_97[0][0]']              
                                8)                                                                
                                                                                                  
 conv2d_98 (Conv2D)             (None, 280, 280, 12  147584      ['activation_92[0][0]']          
                                8)                                                                
                                                                                                  
 activation_93 (Activation)     (None, 280, 280, 12  0           ['conv2d_98[0][0]']              
                                8)                                                                
                                                                                                  
 max_pooling2d_21 (MaxPooling2D  (None, 140, 140, 12  0          ['activation_93[0][0]']          
 )                              8)                                                                
                                                                                                  
 conv2d_99 (Conv2D)             (None, 138, 138, 25  295168      ['max_pooling2d_21[0][0]']       
                                6)                                                                
                                                                                                  
 activation_94 (Activation)     (None, 138, 138, 25  0           ['conv2d_99[0][0]']              
                                6)                                                                
                                                                                                  
 conv2d_100 (Conv2D)            (None, 136, 136, 25  590080      ['activation_94[0][0]']          
                                6)                                                                
                                                                                                  
 activation_95 (Activation)     (None, 136, 136, 25  0           ['conv2d_100[0][0]']             
                                6)                                                                
                                                                                                  
 max_pooling2d_22 (MaxPooling2D  (None, 68, 68, 256)  0          ['activation_95[0][0]']          
 )                                                                                                
                                                                                                  
 conv2d_101 (Conv2D)            (None, 66, 66, 512)  1180160     ['max_pooling2d_22[0][0]']       
                                                                                                  
 activation_96 (Activation)     (None, 66, 66, 512)  0           ['conv2d_101[0][0]']             
                                                                                                  
 conv2d_102 (Conv2D)            (None, 64, 64, 512)  2359808     ['activation_96[0][0]']          
                                                                                                  
 activation_97 (Activation)     (None, 64, 64, 512)  0           ['conv2d_102[0][0]']             
                                                                                                  
 max_pooling2d_23 (MaxPooling2D  (None, 32, 32, 512)  0          ['activation_97[0][0]']          
 )                                                                                                
                                                                                                  
 conv2d_103 (Conv2D)            (None, 30, 30, 1024  4719616     ['max_pooling2d_23[0][0]']       
                                )                                                                 
                                                                                                  
 activation_98 (Activation)     (None, 30, 30, 1024  0           ['conv2d_103[0][0]']             
                                )                                                                 
                                                                                                  
 conv2d_104 (Conv2D)            (None, 28, 28, 1024  9438208     ['activation_98[0][0]']          
                                )                                                                 
                                                                                                  
 activation_99 (Activation)     (None, 28, 28, 1024  0           ['conv2d_104[0][0]']             
                                )                                                                 
                                                                                                  
 conv2d_transpose_20 (Conv2DTra  (None, 56, 56, 512)  2097664    ['activation_99[0][0]']          
 nspose)                                                                                          
                                                                                                  
 tf.image.resize_20 (TFOpLambda  (None, 56, 56, 512)  0          ['max_pooling2d_23[0][0]']       
 )                                                                                                
                                                                                                  
 concatenate_20 (Concatenate)   (None, 56, 56, 1024  0           ['conv2d_transpose_20[0][0]',    
                                )                                 'tf.image.resize_20[0][0]']     
                                                                                                  
 conv2d_105 (Conv2D)            (None, 54, 54, 512)  4719104     ['concatenate_20[0][0]']         
                                                                                                  
 activation_100 (Activation)    (None, 54, 54, 512)  0           ['conv2d_105[0][0]']             
                                                                                                  
 conv2d_106 (Conv2D)            (None, 52, 52, 512)  2359808     ['activation_100[0][0]']         
                                                                                                  
 activation_101 (Activation)    (None, 52, 52, 512)  0           ['conv2d_106[0][0]']             
                                                                                                  
 conv2d_transpose_21 (Conv2DTra  (None, 104, 104, 25  524544     ['activation_101[0][0]']         
 nspose)                        6)                                                                
                                                                                                  
 tf.image.resize_21 (TFOpLambda  (None, 104, 104, 25  0          ['max_pooling2d_22[0][0]']       
 )                              6)                                                                
                                                                                                  
 concatenate_21 (Concatenate)   (None, 104, 104, 51  0           ['conv2d_transpose_21[0][0]',    
                                2)                                'tf.image.resize_21[0][0]']     
                                                                                                  
 conv2d_107 (Conv2D)            (None, 102, 102, 25  1179904     ['concatenate_21[0][0]']         
                                6)                                                                
                                                                                                  
 activation_102 (Activation)    (None, 102, 102, 25  0           ['conv2d_107[0][0]']             
                                6)                                                                
                                                                                                  
 conv2d_108 (Conv2D)            (None, 100, 100, 25  590080      ['activation_102[0][0]']         
                                6)                                                                
                                                                                                  
 activation_103 (Activation)    (None, 100, 100, 25  0           ['conv2d_108[0][0]']             
                                6)                                                                
                                                                                                  
 conv2d_transpose_22 (Conv2DTra  (None, 200, 200, 12  131200     ['activation_103[0][0]']         
 nspose)                        8)                                                                
                                                                                                  
 tf.image.resize_22 (TFOpLambda  (None, 200, 200, 12  0          ['max_pooling2d_21[0][0]']       
 )                              8)                                                                
                                                                                                  
 concatenate_22 (Concatenate)   (None, 200, 200, 25  0           ['conv2d_transpose_22[0][0]',    
                                6)                                'tf.image.resize_22[0][0]']     
                                                                                                  
 conv2d_109 (Conv2D)            (None, 198, 198, 12  295040      ['concatenate_22[0][0]']         
                                8)                                                                
                                                                                                  
 activation_104 (Activation)    (None, 198, 198, 12  0           ['conv2d_109[0][0]']             
                                8)                                                                
                                                                                                  
 conv2d_110 (Conv2D)            (None, 196, 196, 12  147584      ['activation_104[0][0]']         
                                8)                                                                
                                                                                                  
 activation_105 (Activation)    (None, 196, 196, 12  0           ['conv2d_110[0][0]']             
                                8)                                                                
                                                                                                  
 conv2d_transpose_23 (Conv2DTra  (None, 392, 392, 64  32832      ['activation_105[0][0]']         
 nspose)                        )                                                                 
                                                                                                  
 tf.image.resize_23 (TFOpLambda  (None, 392, 392, 64  0          ['max_pooling2d_20[0][0]']       
 )                              )                                                                 
                                                                                                  
 concatenate_23 (Concatenate)   (None, 392, 392, 12  0           ['conv2d_transpose_23[0][0]',    
                                8)                                'tf.image.resize_23[0][0]']     
                                                                                                  
 conv2d_111 (Conv2D)            (None, 390, 390, 64  73792       ['concatenate_23[0][0]']         
                                )                                                                 
                                                                                                  
 activation_106 (Activation)    (None, 390, 390, 64  0           ['conv2d_111[0][0]']             
                                )                                                                 
                                                                                                  
 conv2d_112 (Conv2D)            (None, 388, 388, 64  36928       ['activation_106[0][0]']         
                                )                                                                 
                                                                                                  
 activation_107 (Activation)    (None, 388, 388, 64  0           ['conv2d_112[0][0]']             
                                )                                                                 
                                                                                                  
 conv2d_113 (Conv2D)            (None, 388, 388, 2)  130         ['activation_107[0][0]']         
                                                                                                  
==================================================================================================
Total params: 31,031,810
Trainable params: 31,031,810
Non-trainable params: 0
__________________________________________________________________________________________________
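
The parameter counts in this summary can be checked by hand. A convolution with a k×k kernel, C_in input channels and C_out filters has (k·k·C_in + 1)·C_out parameters (one bias per filter), and the same formula applies to the 2×2 transposed convolutions. The sketch below adds these up, assuming the filter sizes used in `unet_model` (64 doubling to 1024) and that each decoder concatenation doubles the channel count; the function names are illustrative.

```python
def conv_params(kernel: int, c_in: int, c_out: int) -> int:
    """Weights (kernel x kernel x c_in per filter) plus one bias per filter."""
    return (kernel * kernel * c_in + 1) * c_out


def unet_param_count(in_channels: int = 3, num_classes: int = 2,
                     base: int = 64, depth: int = 4) -> int:
    total = 0
    c_in = in_channels
    # Contracting path: two 3x3 convolutions per level, filters double at each level
    for level in range(depth):
        c_out = base * 2 ** level
        total += conv_params(3, c_in, c_out) + conv_params(3, c_out, c_out)
        c_in = c_out
    # Bottleneck: two 3x3 convolutions with base * 2**depth filters
    c_out = base * 2 ** depth
    total += conv_params(3, c_in, c_out) + conv_params(3, c_out, c_out)
    c_in = c_out
    # Expansive path: 2x2 transposed convolution, then two 3x3 convolutions
    for level in reversed(range(depth)):
        c_out = base * 2 ** level
        total += conv_params(2, c_in, c_out)       # Conv2DTranspose
        total += conv_params(3, 2 * c_out, c_out)  # input is upsampled + skip features
        total += conv_params(3, c_out, c_out)
        c_in = c_out
    # Final 1x1 convolution producing one channel per class
    total += conv_params(1, c_in, num_classes)
    return total


print(f"{unet_param_count():,}")  # 31,031,810
```

The result matches the "Total params" line above, which confirms that the model follows the filter progression described in Figure 1.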

Apply to an Image

Input Image: 


Input image

Python3




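import numpy as np
from PIL import Image
from tensorflow.keras.preprocessing import image

# Load the image (forcing 3 RGB channels) and remember its original size
img = Image.open('cat.png').convert('RGB')
original_size = img.size  # (width, height)

# Preprocess: resize to the network input size and scale pixels to [0, 1]
img = img.resize((572, 572))
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)
img_array = img_array / 255.

# Build the model (unet_model is defined above); its weights are
# randomly initialized because the model has not been trained
model = unet_model(input_shape=(572, 572, 3), num_classes=2)

# Make predictions: the output shape is (1, 388, 388, 2)
predictions = model.predict(img_array)

# Take the most likely class per pixel and resize the mask to the original image size.
# The 388x388 output covers only the central region of the 572x572 input,
# so stretching it over the whole image is an approximation.
predictions = np.squeeze(predictions, axis=0)
predictions = np.argmax(predictions, axis=-1)
predictions = Image.fromarray(np.uint8(predictions*255))
predictions = predictions.resize(original_size)

# Save the predicted mask
predictions.save('predicted_image.jpg')
predictions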


Output:

1/1 [==============================] - 2s 2s/step
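Predicted Image

Because the model has not been trained, its weights are random and this mask does not represent a meaningful segmentation of the cat. To obtain a useful mask, the model must first be trained on pairs of input images and ground-truth segmentation masks.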
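
Because the unpadded network maps a 572×572 input to a 388×388 output, the prediction only covers the central region of the input: 92 pixels are lost on each side, since (572 - 388) / 2 = 92. The original paper handles this with an overlap-tile strategy in which missing context at the image border is filled in by mirroring the input. Here is a minimal NumPy sketch of that idea for a single tile, assuming a 388×388 RGB image and using `np.pad` with `mode="reflect"`; `mirror_pad_for_unet` is a hypothetical helper.

```python
import numpy as np


def mirror_pad_for_unet(image: np.ndarray, margin: int = 92) -> np.ndarray:
    """Mirror-pad the height and width of an (H, W, C) image by `margin` pixels per side."""
    return np.pad(image, ((margin, margin), (margin, margin), (0, 0)), mode="reflect")


tile = np.random.default_rng(0).random((388, 388, 3)).astype(np.float32)
padded = mirror_pad_for_unet(tile)
print(padded.shape)  # (572, 572, 3)

# The original tile sits unchanged in the center of the padded input, so each pixel
# of the 388x388 network output lines up with the same pixel of `tile`.
assert np.array_equal(padded[92:480, 92:480], tile)
```

Feeding `padded[np.newaxis]` to the 572×572 model would yield a mask aligned pixel for pixel with `tile`, instead of a 388×388 mask stretched over the whole image.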

Applications:

The versatility and flexibility of the U-Net design have led to its use in many domains beyond biomedical image segmentation. Popular applications include image denoising, image-to-image translation, image super-resolution, object detection, and NLP. You can explore some of these applications in the following articles:

  1. Image Segmentation Using TensorFlow
  2. Image-to-Image Translation using Pix2Pix

