Serving ML model with Flask and Google Colab


Hello friends, in this post we will learn to serve input requests with a simple CNN model. Let's check the depiction below:

FLow_II.png (Image credit - AI and Machine Learning for Coders, by Laurence Moroney)
It is a simple flow diagram of an ML model's lifecycle till the serving stage.
Just like any other software product/app, ML models too have an operations stage, which begins once the model starts serving real data, i.e. production data. This stage is known as MLOps, and we will not touch it in this post.
Before the serving stage come all the usual initial stages of an ML training cycle. This part is shown in the translucent frame in the image above. We will skip that part too, since we assume we are all fairly familiar with those steps.

Request, Response Flow

Let's check our Request-Response flow. If you have never worked with a web API and don't know what a Request and a Response are, let me explain: the query a user sends to the server from a device/browser is the Request, and the answer from the server is the Response.
Check this simple flow diagram: Arch.png We will not build any fancy web app, so we will mimic the request-sending process using the Postman client on the desktop. Postman is a mature app used to test web APIs. You can get it from Here
Our server will be hosted on Google Colab, which is a cloud-based notebook-as-a-service.
Our model will be a CNN trained on the MNIST digits dataset. The request data will be an image of a hand-written digit, drawn on a digital whiteboard with a stylus.

Let's Code

Our server code will have 4 key components:

  1. The model - The trained model, ready to serve, i.e. respond to the predict call.
  2. Pre-processing function - You must be aware of the pre-processing steps used while training the model. Now we have to apply the same steps to the new data. The most important part of this step is to reuse the statistics computed during training (e.g. mean/std), because you can't recompute these statistics over training + new test data on every request.
    We will also place any generic utility here, e.g. MNIST is trained on black-and-white data but we may receive a coloured image in the request.
  3. Web API - This is our Flask-based API. Its responsibility is to accept the request, perform basic data validation, call the pre-processing/predict functions, and finally send the Response back to the client.
  4. ngrok overhead - This component is not one of our required parts, but we need it as a workaround to get a public API out of Google Colab. You can safely skip the details and use it as-is, i.e. copy-paste.
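As a concrete illustration of point 2, here is a minimal sketch of persisting the training-set statistics once and reusing them at serving time. The function names and the JSON file are hypothetical, not part of the server code below:

```python
import json
import numpy as np

def save_stats(x_train, path):
    # Compute the scaling statistics ONCE, on the training set only
    stats = {"min": float(x_train.min()), "max": float(x_train.max())}
    with open(path, "w") as f:
        json.dump(stats, f)
    return stats

def scale(arr, stats):
    # Min-max scale new data with the STORED training statistics,
    # never with statistics recomputed on the incoming request
    return (arr - stats["min"]) / (stats["max"] - stats["min"])
```

For MNIST pixel data the stored values are simply 0 and 255, which is why the server code below can hard-code them.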

Model and Flask Server

import numpy as np
from keras.models import load_model
from tensorflow import keras

# Load the model
model = load_model('/content/drive/MyDrive/Colab Notebooks/Blogs/10xAI_Blog_0022_ML-Serving/CNN_MNIST.h5')
x_train_max = 255.
x_train_min = 0.
path = "/content"

def pre_process(img_path):
    # load, B/W, Resize image
    img = keras.preprocessing.image.load_img(img_path, color_mode='grayscale', target_size=(28, 28))    
    img_arr = keras.preprocessing.image.img_to_array(img)

    # scale
    img_arr = (img_arr - x_train_min)/(x_train_max - x_train_min)
    img_arr = img_arr[np.newaxis,:,:]

    # predict and return
    return np.argmax(model.predict(img_arr), axis=-1)[0]

The image is loaded from the path in grayscale format and resized to 28x28.
The rest of the code is quite straightforward and self-explanatory.
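One detail worth making explicit: `model.predict` expects a batch dimension. A quick stand-alone check, with a dummy array standing in for the decoded image:

```python
import numpy as np

# img_to_array yields shape (28, 28, 1) for a grayscale 28x28 image
img_arr = np.zeros((28, 28, 1), dtype="float32")

# Same scaling and batching steps as in pre_process above
img_arr = (img_arr - 0.0) / (255.0 - 0.0)
img_arr = img_arr[np.newaxis, :, :]

print(img_arr.shape)  # (1, 28, 28, 1)
```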

With the above code snippet, we have defined the training statistics, loaded the model, and created the pre-processing function.
Let's build the Flask API.

from flask import Flask, request, Response
import json
import time, os

app = Flask(__name__)

# Create a method for /
@app.route('/')
def home():
    return "<h1>Running Flask on Google Colab!</h1>"

# Post method for Predict
@app.route('/predict', methods=['POST'])
def predict_():

    # Get request param
    uploaded_file = request.files['file']  #1

    # Check it is a valid image file
    # Do it yourself

    # Save to DB/Disk
    img_path = os.path.join(path, uploaded_file.filename + time.strftime("%Y%m%d-%H%M%S"))
    uploaded_file.save(img_path)  #2

    # Pre-processing and prediction
    digit_class = pre_process(img_path)  #3

    # Prepare output
    res = {"Digit": int(digit_class)}  #4

    return Response(response=json.dumps(res), status=200, mimetype='application/json')

#1 - Up to this line, we have used standard Flask boilerplate. Check the official docs Here
#2 - We should save our input for further analysis. For demo purposes we just save it to disk, with the timestamp concatenated into the filename
#3 - Called the pre_process function
#4 - Built the simple response JSON
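The validation left as "Do it yourself" at #1 could start, for example, with a light magic-bytes check. This is only a sketch; a production service would fully decode the file before trusting it:

```python
PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
JPEG_MAGIC = b"\xff\xd8\xff"

def looks_like_image(data: bytes) -> bool:
    # Cheap first-line check on the file signature; it does not prove
    # the payload is a decodable image, only that it claims to be one
    return data.startswith(PNG_MAGIC) or data.startswith(JPEG_MAGIC)
```

Inside predict_ you could read a prefix with `uploaded_file.stream.read(8)`, run the check, then `uploaded_file.stream.seek(0)` before saving.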

Now we are left with creating a tunnel using ngrok to make this API publicly available. We will use pyngrok for this. After that, we will start the Flask server.

# Need not know the details of this code
# This binds localhost:80 to an internet address and returns that address
# Use that address to call the API from anywhere

!pip install pyngrok --quiet
from pyngrok import ngrok
# Recent pyngrok versions take the port as the first positional argument
public_url = ngrok.connect(80, "http")
print("Public URL:", public_url)

# Start the server
app.run(host='0.0.0.0', port=80, debug=False)

Result: ngrok.PNG

Great!! We are ready with our server up and running.

Postman client

Let's go to the Postman client and call the API. Here is a screenshot from the Postman client; the red text in the green boxes is explanatory. Postman.png

  1. Select the method as "Post"
  2. Put the URL that we got from the ngrok code (see the previous image)
  3. Click the Body tab, create a key with the name "file", and upload the image using the browse button
  4. Click Send and you will receive the Response
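If you prefer to skip Postman, the same multipart POST can be built from Python with the requests library. A sketch, where the ngrok URL and image path are placeholders you must replace:

```python
import requests

def build_predict_request(url, img_path):
    # Multipart POST with the image under the key "file",
    # matching the key created in Postman's Body tab
    with open(img_path, "rb") as f:
        return requests.Request("POST", url, files={"file": f}).prepare()

# To actually send it (assuming the server is up):
# resp = requests.Session().send(
#     build_predict_request("https://<your-id>.ngrok.io/predict", "digit.png"))
# print(resp.json())
```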


This concludes our end-to-end code for ML serving. The best part is that you can do it all in Google Colab. You can extend this code to a very different use case, e.g. tabular data, or try it on a cat/dog dataset with your own model.
Just be mindful that a real-life scenario will not work with such simplistic validation; e.g. the image should be saved to a file-hosting server and its metadata in a database. Secondly, nothing special has been done here to get top-tier performance.