Transfer learning, as a general term, refers to reusing the knowledge learned on one task for another. For convolutional neural networks (CNNs) in particular, many image features are common to a wide variety of datasets (e.g., lines and edges appear in almost every image). For this reason, large CNN architectures are rarely trained completely from scratch: the large datasets and heavy computational resources this requires are hard to come by.
A common pretraining dataset is ImageNet, which contains about 1.2 million training images. The choice of model varies from task to task (often, people simply pick whichever model performs best on the ImageNet challenge); this article uses ResNet50. Pretrained models are usually available through the deep learning library being used, which in this case is Keras.
ResNet was originally designed to address the vanishing gradient problem, in which backpropagated gradients become extremely small as they are multiplied together layer after layer, limiting how deep a network can be trained effectively. The ResNet architecture tackles this with skip connections: shortcuts that let a block's input bypass its layers and be added directly to their output.
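As a rough numerical illustration (a toy scalar model, not ResNet itself, whose blocks use convolutions, batch normalization and ReLU), the sketch below multiplies local derivatives through a deep chain of sigmoid layers. A plain layer x → σ(x) contributes σ'(x) ≤ 0.25 to the chain-rule product, whereas a residual layer x → x + σ(x) contributes 1 + σ'(x) ≥ 1, so the gradient cannot collapse toward zero. The function names are hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_grad(x: float) -> float:
    s = sigmoid(x)
    return s * (1.0 - s)  # never larger than 0.25 (the value at x = 0)


def chain_gradient(depth: int, residual: bool, x0: float = 0.0) -> float:
    """d(output)/d(input) through `depth` stacked scalar layers (toy model).

    Plain layer:    x_next = sigmoid(x)      -> local derivative sigmoid'(x)
    Residual layer: x_next = x + sigmoid(x)  -> local derivative 1 + sigmoid'(x)
    """
    x, grad = x0, 1.0
    for _ in range(depth):
        local = sigmoid_grad(x)
        if residual:
            grad *= 1.0 + local  # the identity shortcut adds 1 to the local derivative
            x = x + sigmoid(x)
        else:
            grad *= local        # repeated factors <= 0.25 shrink the gradient
            x = sigmoid(x)
    return grad


if __name__ == "__main__":
    print("plain, 50 layers:   ", chain_gradient(50, residual=False))
    print("residual, 50 layers:", chain_gradient(50, residual=True))
```

With 50 layers, the plain chain's gradient falls below 0.25^50 (about 8e-31), while the residual chain's gradient stays at least 1.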
The model consists of a series of convolutional layers with skip connections, followed by average pooling and a final fully connected (dense) output layer. For transfer learning, we only want the convolutional layers, since those contain the features we're interested in, so we omit the pooling and output layers when importing the model. Because we're removing the output layers, we then need to replace them with our own series of layers.
To show the process of transfer learning, I’ll be using the Caltech-101 dataset, an image dataset with 101 categories and about 40-800 images per category.
First, download and extract the dataset. After extraction, remove the “BACKGROUND_Google” folder, since it contains background clutter images rather than one of the 101 object categories.
Code: To evaluate the model properly, we also need to split the data into training and test sets. The split is done within each category so that every class is properly represented in the test set.
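The original splitting code is not reproduced here; the following is a minimal sketch of a per-category split using only the standard library. It assumes the extracted 101_ObjectCategories folder is reused as the training set and that a random fraction of each category's images is moved into caltech_test. The function name, the 20% test fraction and the fixed seed are illustrative assumptions, not values from the article.

```python
import os
import random
import shutil


def split_per_category(train_dir: str, test_dir: str,
                       test_fraction: float = 0.2, seed: int = 0) -> dict[str, int]:
    """Move a random `test_fraction` of each category's files into test_dir.

    Returns the number of files moved per category (hypothetical helper).
    """
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    moved = {}
    for category in sorted(os.listdir(train_dir)):
        src = os.path.join(train_dir, category)
        if not os.path.isdir(src):
            continue  # skip stray files at the top level
        files = sorted(os.listdir(src))
        rng.shuffle(files)
        n_test = int(len(files) * test_fraction)
        dst = os.path.join(test_dir, category)
        os.makedirs(dst, exist_ok=True)  # mirror the category folder in the test set
        for name in files[:n_test]:
            shutil.move(os.path.join(src, name), os.path.join(dst, name))
        moved[category] = n_test
    return moved
```

A call such as `split_per_category("101_ObjectCategories", "caltech_test")` would leave the remaining images as the training set. Moving rather than copying keeps the two sets disjoint without duplicating files on disk.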
Splitting the data this way produces the following file structure:
101_ObjectCategories/
-- accordion
-- airplanes
-- anchor
-- ...
caltech_test/
-- accordion
-- airplanes
-- anchor
-- ...
The first folder contains the training images and the second contains the test images; each subfolder holds the images belonging to one category. To feed in the data, we're going to use Keras's ImageDataGenerator class, which makes it easy to load and preprocess image data and also provides options for augmentation.
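Calling flow_from_directory on the generator with the path of an image directory creates an iterator that yields batches of images and their labels, with each image's label taken from the name of the subfolder it sits in.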
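A minimal sketch of the generators, assuming a TensorFlow 2.x install whose tf.keras still provides ImageDataGenerator (recent TensorFlow documentation marks it as not recommended for new code). The 224×224 target size, the batch size of 32 and the use of ResNet50's preprocess_input are assumptions, and make_generator is a hypothetical helper name.

```python
import tensorflow as tf

IMAGE_SIZE = (224, 224)  # assumption: ResNet50's default input resolution
BATCH_SIZE = 32          # assumption: not stated in the article


def make_generator(directory: str, shuffle: bool = True):
    """Iterator over (images, one-hot labels) read from class subfolders."""
    datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        # apply the same input preprocessing ResNet50 was trained with
        preprocessing_function=tf.keras.applications.resnet50.preprocess_input,
    )
    return datagen.flow_from_directory(
        directory,
        target_size=IMAGE_SIZE,     # every image is resized to this shape
        batch_size=BATCH_SIZE,
        class_mode="categorical",   # one-hot labels for a softmax output
        shuffle=shuffle,
    )
```

For example, `train_gen = make_generator("101_ObjectCategories")` and `test_gen = make_generator("caltech_test", shuffle=False)`; turning shuffling off for the test set keeps its order stable so predictions line up with `test_gen.filenames`.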
Code: Adding the base pretrained model
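A sketch of loading the pretrained base through tf.keras.applications. Setting include_top=False drops the average-pooling and 1000-way dense output layers described earlier; the 224×224×3 input shape is an assumption chosen to match the generator settings, and build_base is a hypothetical helper name. Passing weights=None builds the same architecture with random weights and no download, which is handy for quick shape checks.

```python
import tensorflow as tf


def build_base(weights="imagenet", input_shape=(224, 224, 3)) -> tf.keras.Model:
    """ResNet50 convolutional base without its classification head."""
    return tf.keras.applications.ResNet50(
        weights=weights,          # "imagenet": download pretrained weights on first use
        include_top=False,        # omit the average pooling + 1000-way dense output
        input_shape=input_shape,
    )
```

For 224×224 inputs, `build_base()` outputs a 7×7×2048 feature map per image, which the new classifier will consume.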
This dataset is relatively small, at around 5,628 images after splitting, with most categories having only about 50 images, so fine-tuning the convolutional layers may result in overfitting. Our new dataset is also quite similar to ImageNet, so we can be confident that many of the pretrained weights already capture the right features. We can therefore freeze the pretrained convolutional layers so that they aren't changed when we train the rest of the classifier. If you have a small dataset that is significantly different from the original, fine-tuning may still cause overfitting, but the later layers wouldn't contain the right features either. In that case, you could again freeze the convolutional layers but take the output from earlier layers, since those contain more general features. With a large dataset, overfitting is much less of a concern, so you can often fine-tune the entire network.
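A sketch of the two freezing strategies, assuming a ResNet50 base loaded with include_top=False. Setting trainable = False on a model freezes all of its layers. For the dissimilar-dataset case, truncated_extractor (a hypothetical helper) cuts the network at an earlier block; the default layer name conv4_block6_out is the output of ResNet50's fourth stage in tf.keras.applications and is chosen here only as an example.

```python
import tensorflow as tf


def freeze(model: tf.keras.Model) -> tf.keras.Model:
    """Freeze every layer so its weights are not updated during training."""
    model.trainable = False
    return model


def truncated_extractor(base: tf.keras.Model,
                        layer_name: str = "conv4_block6_out") -> tf.keras.Model:
    """Feature extractor that stops at an earlier, more general layer of `base`.

    The new model shares its layers (and weights) with `base`, so freezing
    `base` also freezes this extractor.
    """
    return tf.keras.Model(inputs=base.input,
                          outputs=base.get_layer(layer_name).output)
```

For 224×224 inputs, the truncated extractor produces a 14×14×1024 feature map instead of the final 7×7×2048 one.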
Now we can add the rest of the classifier. It takes the output of the pretrained convolutional layers and feeds it into a separate classifier that is trained on the new dataset.
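A minimal sketch of such a classifier head, assuming a frozen base as above. Global average pooling, a hidden Dense layer and Dropout are one common choice; the 1024 hidden units and 0.5 dropout rate are illustrative assumptions, and add_classifier is a hypothetical helper name.

```python
import tensorflow as tf

NUM_CLASSES = 101  # Caltech-101 object categories


def add_classifier(base: tf.keras.Model, num_classes: int = NUM_CLASSES,
                   hidden_units: int = 1024, dropout: float = 0.5) -> tf.keras.Model:
    """Stack a new trainable head on top of the (frozen) convolutional base."""
    x = tf.keras.layers.GlobalAveragePooling2D()(base.output)    # 7x7x2048 -> 2048
    x = tf.keras.layers.Dense(hidden_units, activation="relu")(x)
    x = tf.keras.layers.Dropout(dropout)(x)                      # regularizes the new head
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs=base.input, outputs=outputs)
```

If the base was frozen first, only the two new Dense layers (their kernels and biases) are trained; hidden_units and dropout are the knobs worth experimenting with later.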
Code: Training the model
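A minimal sketch of compiling and training, assuming the model and generators from the previous steps. The Adam optimizer, the categorical cross-entropy loss (matching one-hot labels from class_mode="categorical") and passing the test generator as validation data are assumptions. The acc/val_acc names in the log suggest the metric was requested as "acc"; this sketch asks for "accuracy", so its history keys are accuracy and val_accuracy.

```python
import tensorflow as tf


def compile_and_train(model: tf.keras.Model, train_data, val_data, epochs: int = 5):
    """Compile `model` for multi-class classification and fit it."""
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",  # expects one-hot encoded labels
        metrics=["accuracy"],
    )
    # Returns a History object whose .history dict holds per-epoch metrics.
    return model.fit(train_data, validation_data=val_data, epochs=epochs)
```

Usage would look like `history = compile_and_train(model, train_gen, test_gen, epochs=5)`.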
Epoch 1/5
176/176 [==============================] - 27s 156ms/step - loss: 1.6601 - acc: 0.6338 - val_loss: 0.3799 - val_acc: 0.8922
Epoch 2/5
176/176 [==============================] - 19s 107ms/step - loss: 0.4637 - acc: 0.8696 - val_loss: 0.2841 - val_acc: 0.9225
Epoch 3/5
176/176 [==============================] - 19s 107ms/step - loss: 0.2777 - acc: 0.9211 - val_loss: 0.2714 - val_acc: 0.9225
Epoch 4/5
176/176 [==============================] - 19s 107ms/step - loss: 0.2223 - acc: 0.9327 - val_loss: 0.2419 - val_acc: 0.9284
Epoch 5/5
176/176 [==============================] - 19s 106ms/step - loss: 0.1784 - acc: 0.9461 - val_loss: 0.2499 - val_acc: 0.9239
Code: Evaluating on the test set
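A sketch of the evaluation step, assuming the model was compiled with a single accuracy metric so that model.evaluate returns [loss, accuracy]. The helper name is hypothetical; the print format mirrors the result sentence in the output that follows.

```python
def evaluate(model, test_data):
    """Evaluate a compiled Keras model and report loss and accuracy."""
    loss, accuracy = model.evaluate(test_data)  # [loss, accuracy] for one metric
    print(f"The model achieved a loss of {loss:.2f} and accuracy of {accuracy:.2%}.")
    return loss, accuracy
```

Called as `evaluate(model, test_gen)`, an accuracy of 0.928 would be printed as 92.80%.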
53/53 [==============================] - 5s 95ms/step
The model achieved a loss of 0.23 and accuracy of 92.80%.
For a 101-class dataset, the model reaches 92.8% test accuracy after only 5 epochs. For perspective, the original ResNet was trained on the roughly 1.2-million-image ImageNet dataset for about 120 epochs.
There are a couple of things that could be improved. First, the model is starting to overfit: in the last epoch, training loss kept falling (from 0.2223 to 0.1784) while validation loss rose slightly (from 0.2419 to 0.2499). One way to address this is image augmentation, which can be implemented easily with the ImageDataGenerator class. You could also experiment with adding or removing layers, or with hyperparameters such as the dropout rate or the size of the Dense layer.
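A sketch of adding simple augmentation with ImageDataGenerator, assuming the same tf.keras setup as before. The specific ranges are illustrative assumptions, not tuned values, and the helper name is hypothetical. Augmentation should be applied only to the training generator; the test generator should keep plain preprocessing so the evaluation images stay unchanged.

```python
import tensorflow as tf


def make_augmented_datagen():
    """ImageDataGenerator with light random augmentation for training only."""
    return tf.keras.preprocessing.image.ImageDataGenerator(
        preprocessing_function=tf.keras.applications.resnet50.preprocess_input,
        rotation_range=15,        # random rotations up to +/-15 degrees
        width_shift_range=0.1,    # horizontal shifts up to 10% of the width
        height_shift_range=0.1,   # vertical shifts up to 10% of the height
        zoom_range=0.1,           # random zoom in [0.9, 1.1]
        horizontal_flip=True,     # random left-right flips
    )
```

The training iterator would then be created as `make_augmented_datagen().flow_from_directory("101_ObjectCategories", target_size=(224, 224), batch_size=32, class_mode="categorical")`, while the test iterator stays as before.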