Deploying a TensorFlow 2.1 CNN model on the web with Flask
When starting out in machine learning, one of the biggest hurdles people face is deploying what they build to the web so that it is easier to demonstrate and use. This article helps beginners bridge the gap between creating a TensorFlow model and deploying it on the web with Flask, and hopefully offers some insight into the issues that come up when combining TensorFlow and Flask.
Creating a Model
First, we train our model using TensorFlow and Keras's ImageDataGenerator. For this, we downloaded 12,500 images of cats and 12,500 images of dogs. To train a model using ImageDataGenerator, first install the required packages:
pip install keras
pip install tensorflow
We trained the model using transfer learning from the InceptionV3 model. For a tutorial on transfer learning, visit this link. After the model has been trained, we execute the following command, which saves the trained model as a folder. (Since TensorFlow 2.0, Keras saves models in the SavedModel format, which is a folder, by default; earlier versions created a single model file instead.)
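A minimal sketch of this transfer-learning step, assuming tf.keras from TensorFlow 2.x and a hypothetical data/train/ folder containing cats/ and dogs/ subfolders. It freezes an ImageNet-pretrained InceptionV3 base, adds a small sigmoid head for the two classes, trains it with ImageDataGenerator, and saves the result as a folder. The function names, image size, epoch count and the save path "model" are illustrative choices, not the article's exact code.

```python
import tensorflow as tf
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.applications.inception_v3 import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMG_SIZE = (299, 299)  # InceptionV3's default input resolution


def build_model(weights="imagenet"):
    """Frozen InceptionV3 base plus a small binary (cat vs dog) head."""
    base = InceptionV3(weights=weights, include_top=False,
                       input_shape=IMG_SIZE + (3,))
    base.trainable = False  # keep the pretrained features; train only the head
    model = tf.keras.Sequential([
        base,
        tf.keras.layers.GlobalAveragePooling2D(),
        # Probability of class index 1 (dogs/, given folders cats/ and dogs/)
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=["accuracy"])
    return model


def train_and_save(data_dir="data/train", save_path="model", epochs=3):
    # Hypothetical layout: data/train/cats/*.jpg and data/train/dogs/*.jpg
    datagen = ImageDataGenerator(preprocessing_function=preprocess_input,
                                 validation_split=0.2)
    flow = dict(target_size=IMG_SIZE, batch_size=32, class_mode="binary")
    train = datagen.flow_from_directory(data_dir, subset="training", **flow)
    val = datagen.flow_from_directory(data_dir, subset="validation", **flow)
    model = build_model()
    model.fit(train, validation_data=val, epochs=epochs)
    # In TF 2.1 (Keras 2), a path without ".h5" is saved as a SavedModel folder
    model.save(save_path)
    return model
```

Whatever preprocessing is used here (preprocess_input in this sketch) must also be applied at prediction time. On TensorFlow 2.16 and later, which ship Keras 3, model.save() requires a filename ending in .keras instead of a bare folder name.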
Creating Flask Application
Next, we install Flask, a Python micro-framework that is widely used to serve ML models on the web, with the following command:
pip install flask
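To check that Flask is installed, a minimal application might look like the sketch below; the route text is illustrative. By default, Flask(__name__) looks for HTML templates in a templates/ folder and for static files in a static/ folder next to the script, which is why the article places upload.html in templates and upload.js in static.

```python
from flask import Flask

# By default Flask(__name__) serves templates from ./templates and static
# files (JS, CSS, images) from ./static, relative to this script.
app = Flask(__name__)


@app.route("/")
def index():
    # Plain-text response; a real app would render upload.html here instead
    return "Flask is working"
```

To start it from the command line, end the file with `if __name__ == "__main__": app.run()` and run it with python; the development server then listens on http://127.0.0.1:5000/ by default.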
Next comes the webpage. Create an HTML page to suit your preferences; we created a basic webpage with one button to upload an image and another to predict what the image shows. This code sends the image to the server as a POST request, and the image is stored in the folder "uploaded/image/", from which the prediction is made. The code is courtesy of Shubham S. Naik on Stack Overflow and is split into two files, a .js file and a .html file. Copy the code below into a file named upload.html and save it in a folder named templates.
Copy the code below into a file named upload.js and save it in a folder called static.
Once this is done, we use ImageDataGenerator again, this time to make a prediction on an image with the trained model. First, we load the saved model folder.
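A hedged sketch of this prediction step, useful for checking the model outside Flask. It assumes the model was trained on 299x299 inputs with InceptionV3's preprocess_input and that class index 0 is cat and 1 is dog (the alphabetical order flow_from_directory assigns to cats/ and dogs/ folders). flow_from_directory expects images inside class subfolders, which is presumably why uploads go to uploaded/image/: "image" acts as a single dummy class. The function names and the folder name "model" are hypothetical.

```python
import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMG_SIZE = (299, 299)          # assumed training input size
CLASS_NAMES = ["cat", "dog"]   # assumed label order used during training


def load_trained_model(path="model"):
    # Hypothetical folder name; TF 2.x load_model accepts a SavedModel folder
    return tf.keras.models.load_model(path)


def predict_uploaded(model, upload_root="uploaded"):
    """Return one label per image found in upload_root/<subfolder>/."""
    datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
    gen = datagen.flow_from_directory(
        upload_root,
        target_size=IMG_SIZE,  # resize uploads to the model's input size
        batch_size=1,
        class_mode=None,       # yield images only; no labels at prediction time
        shuffle=False,         # keep predictions in gen.filenames order
    )
    probs = model.predict(gen, steps=len(gen)).ravel()  # one sigmoid output each
    return [CLASS_NAMES[int(p > 0.5)] for p in probs]
```

For example, `predict_uploaded(load_trained_model("model"))` returns a list such as `["dog"]` when uploaded/image/ holds a single picture. Because the generator reads every image in the folder, old uploads produce extra predictions.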
Then, we import ImageDataGenerator. Now for the Flask code: copy the code below into a file named code.py and place it in the project's root directory. The complete Python code is as follows:
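A minimal sketch of what code.py could contain, under several assumptions: the saved model folder is named model, the upload form posts the image in a field called file to a /predict route, and pred.html displays a template variable named prediction. All of these names are hypothetical and must match your own upload.js and pred.html. The model is loaded lazily on the first request and then reused, and the upload folder is emptied before each new image is saved so the generator sees exactly one file.

```python
import os
import shutil

import tensorflow as tf
from flask import Flask, render_template, request
from tensorflow.keras.applications.inception_v3 import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from werkzeug.utils import secure_filename

MODEL_PATH = "model"                             # hypothetical model folder name
UPLOAD_ROOT = "uploaded"                         # root passed to flow_from_directory
UPLOAD_DIR = os.path.join(UPLOAD_ROOT, "image")  # single dummy "class" subfolder
IMG_SIZE = (299, 299)                            # assumed training input size
CLASS_NAMES = ["cat", "dog"]                     # assumed training label order

app = Flask(__name__)
_model = None


def get_model():
    """Load the saved model on first use and cache it for later requests."""
    global _model
    if _model is None:
        _model = tf.keras.models.load_model(MODEL_PATH)
    return _model


@app.route("/")
def index():
    return render_template("upload.html")


@app.route("/predict", methods=["POST"])
def predict():
    upload = request.files.get("file")  # hypothetical form field name
    if upload is None or upload.filename == "":
        return "No image uploaded", 400

    # Keep only the new image so the generator predicts exactly one file
    shutil.rmtree(UPLOAD_DIR, ignore_errors=True)
    os.makedirs(UPLOAD_DIR)
    name = secure_filename(upload.filename) or "upload.jpg"  # fallback if sanitized away
    upload.save(os.path.join(UPLOAD_DIR, name))

    gen = ImageDataGenerator(preprocessing_function=preprocess_input).flow_from_directory(
        UPLOAD_ROOT, target_size=IMG_SIZE, batch_size=1,
        class_mode=None, shuffle=False)
    if gen.samples == 0:  # e.g. an extension flow_from_directory does not accept
        return "Unsupported image type", 400

    prob = float(get_model().predict(gen, steps=len(gen)).ravel()[0])
    label = CLASS_NAMES[int(prob > 0.5)]
    return render_template("pred.html", prediction=label)
```

To launch it with python code.py, end the file with `if __name__ == "__main__": app.run()`; the development server then listens on http://127.0.0.1:5000/. This sketch serves one user at a time; concurrent uploads would overwrite each other in uploaded/image. Note also that code.py shares its name with Python's standard-library `code` module, so a name such as app.py avoids accidental shadowing.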
Finally, we need an HTML page to display the prediction. Copy the code below and save it as pred.html in the templates folder.
The final folder structure is as follows (folders end with "/"):
code.py
<saved model folder>/
static/
    upload.js
templates/
    upload.html
    pred.html
uploaded/
    image/
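If you prefer to create the empty folders from a script, a small sketch using only the standard library could look like this. The function name is hypothetical; the folder names follow the article.

```python
from pathlib import Path

# Folders used in this article, relative to the project root
FOLDERS = ["templates", "static", "uploaded/image"]


def scaffold(root="."):
    """Create the project folders if they do not exist yet; return their paths."""
    created = []
    for name in FOLDERS:
        path = Path(root) / name
        path.mkdir(parents=True, exist_ok=True)  # no error if it already exists
        created.append(path)
    return created
```

Running it twice is harmless, since exist_ok=True skips folders that are already there.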
Finally, run code.py (for example, with python code.py) and open http://127.0.0.1:5000/ in a browser to see the output.
Note: after every prediction, delete the images in uploaded/image before uploading a new one; otherwise, you might get an Internal Server Error.
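Rather than deleting old images by hand, the folder can be emptied automatically right before each new upload is saved. One likely reason stale images cause trouble is that ImageDataGenerator's flow_from_directory reads every image in the folder, not just the newest one. A minimal sketch with the standard library; the helper name is hypothetical, and it assumes nothing else is stored in uploaded/image:

```python
import os
import shutil


def reset_upload_dir(path=os.path.join("uploaded", "image")):
    """Delete everything in the upload folder and recreate it empty."""
    shutil.rmtree(path, ignore_errors=True)  # no error if the folder is missing
    os.makedirs(path)                        # recreate the (now empty) folder
```

Calling reset_upload_dir() in the Flask upload route, just before saving the new file, keeps exactly one image in the folder per prediction.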