Building a sketch retrieval system with PyTorch + Flask (2)
1. Summary and preview
- Preparations (covered in part 1)
- Back-end construction (this article)
- Front-end construction
- Front-end/back-end interaction
- Demo
The previous article, "Building a sketch retrieval system with PyTorch + Flask (1)", covered the preparations for the system, its retrieval principle, and how the Siamese network SketchTriplet extracts features. The pre-trained model, the model code, and the dataset are linked at the end of that article and can be downloaded from Google Drive or Baidu Cloud. This article covers the back end of the retrieval system: building a server with Flask and loading the model with PyTorch. This update has been delayed for a long time because of my studies and work. I will do my best to publish the remaining parts soon, and I apologize to the readers who have been waiting for half a year.
2. Building the server with Flask
There are plenty of tutorials online. I learned from the official Flask documentation (its Chinese translation), which is extremely detailed, easy to follow, and starts from zero. To help you get started quickly, here is my own brief understanding. A Flask app loosely resembles the MVC (Model, View, Controller) pattern and has two conventional folders: static and templates. The static folder stores resources loaded by the web pages (CSS, images, JSON, etc.). For an image img_0.png served by the static endpoint, the URL is built with:
url_for('static', filename='img_0.png')
Its location in the file system is static/img_0.png, and with the development server running on its default port you open it in the browser at http://127.0.0.1:5000/static/img_0.png. An app can have multiple web pages (HTML templates), and they are stored in templates. So, roughly speaking, templates plays the role of V in MVC, static plays the role of M, and the Python file in the root directory plays the role of C. In other words, an app should have the following file structure:
base_folder
    static
        img
        css
        ...
    templates
        0.html
        1.html
        ...
    controller.py
Note that static and templates do not have to use exactly these names; you can change them when initializing Flask, as follows:
app = Flask(__name__, template_folder='templates', static_folder='static')
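To see how the endpoint name, the folder, and the URL relate, here is a minimal sketch using Flask's `test_request_context`. The file name img_0.png is just the example from above (the file does not need to exist for `url_for` to build the URL), and the second app with a folder named `assets` is a hypothetical variation added only for comparison.

```python
from flask import Flask, url_for

# Same constructor call as above; 'templates' and 'static' are also Flask's defaults.
app = Flask(__name__, template_folder='templates', static_folder='static')

# Hypothetical second app whose static folder has a different name.
app_custom = Flask(__name__, static_folder='assets')

with app.test_request_context():
    # url_for takes the endpoint name 'static', not a file-system path.
    img_url = url_for('static', filename='img_0.png')

with app_custom.test_request_context():
    # The endpoint is still called 'static', but the URL follows the folder name.
    custom_url = url_for('static', filename='img_0.png')

print(img_url)     # /static/img_0.png
print(custom_url)  # /assets/img_0.png
```

In templates you would write `{{ url_for('static', filename='img_0.png') }}` instead of hard-coding the path, so renaming the folder does not break the pages.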
2.1 Controller Implementation
For the sketch retrieval system, the workflow is as follows. First, there must be an interface for drawing a sketch, which is front-end work. Next, the drawn sketch is saved and uploaded to the back end, which is front-end/back-end interaction. Then the back end searches with the uploaded sketch and gets the results, which is back-end work. The results are then returned to the front end, which is again front-end/back-end interaction. Finally, the results are displayed, which is front-end work. In other words, we need two front-end parts, two interactions, and one back end.
Following this plan, we first leave a placeholder for the drawing interface: create a new controller.py in the root directory and add a Flask route with the code below, so that the server can run and be accessed:
from datetime import timedelta

from flask import Flask, render_template

# Create the app
app = Flask(__name__, template_folder='templates', static_folder='static')
# Set the cache expiration time for static files
# (the config key replaces the old app.send_file_max_age_default attribute,
# which was removed in Flask 2.3)
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = timedelta(seconds=1)


# Create a route that serves a placeholder page for the drawing interface
@app.route('/')
def hello():
    return render_template('canva.html')


if __name__ == '__main__':
    # debug=True enables auto-reload and the interactive debugger
    app.run(debug=True)
Then create a static web page canva.html in the ./templates folder:
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
</head>
<body>
    <h1>HELLO WORLD! I'm canvas</h1>
</body>
</html>
Now run it to see the effect: run it directly in PyCharm, or run python controller.py. Open the address printed in the output, * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit), and you will see the placeholder drawing interface with its HELLO WORLD heading.
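If you want to check the route without opening a browser, Flask's built-in test client can request it in-process. The sketch below is meant to be self-contained, so it swaps `render_template('canva.html')` for `render_template_string` with the same markup; that substitution, and the constant name `CANVA_HTML`, are assumptions made only for this example.

```python
from datetime import timedelta

from flask import Flask, render_template_string

app = Flask(__name__)
# Same cache setting as in controller.py, written as a config key.
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = timedelta(seconds=1)

# Hypothetical inline stand-in for templates/canva.html.
CANVA_HTML = """<!DOCTYPE html>
<html lang="en">
<head><meta charset="UTF-8"></head>
<body><h1>HELLO WORLD! I'm canvas</h1></body>
</html>"""


@app.route('/')
def hello():
    return render_template_string(CANVA_HTML)


# The test client sends requests directly to the app; no server is started.
client = app.test_client()
response = client.get('/')
print(response.status_code)             # 200
print(b"HELLO WORLD" in response.data)  # True
```

The same pattern is handy later for exercising the upload route with form data instead of clicking through the UI.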
2.2 Add model loading
Next, extend controller.py with model loading and an upload function:
import base64
import json
import os
import time
from datetime import timedelta

import torch
from flask import Flask, render_template, request

# Load the SketchTriplet modules
from SketchTriplet.SketchTriplet_half_sharing import BranchNet
from SketchTriplet.SketchTriplet_half_sharing import SketchTriplet as SketchTriplet_hs
from SketchTriplet.flickr15k_dataset import flickr15k_dataset_lite
from SketchTriplet.retrieval import retrieval


# Define the model-loading function
def load_model_retrieval():
    # Relative path of the model weights
    net_dict_path = '../SketchTriplet/500.pth'
    branch_net = BranchNet()  # branch for the photo edge maps
    net = SketchTriplet_hs(branch_net)
    net.load_state_dict(torch.load(net_dict_path))
    net = net.cuda()  # requires a CUDA-capable GPU
    net.eval()
    return net


# ----------------------------------------
# Load the flickr15k dataset
flickr15k_dataset = flickr15k_dataset_lite()
# Load the retrieval model
retrieval_net = load_model_retrieval()
# ----------------------------------------

# Create the app
app = Flask(__name__, template_folder='templates', static_folder='static')
# Set the cache expiration time for static files
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = timedelta(seconds=1)


# Route that serves the drawing interface
@app.route('/')
def hello():
    return render_template('canva.html')


# Upload route; it accepts both POST and GET because it receives data
@app.route('/upload', methods=['POST', 'GET'])
def upload():
    if request.method == 'POST':
        # Sketch drawn with the mouse (a Base64 data URL)
        sketch_src = request.form.get("sketchUpload")
        # Flag set when the user uploads a local file instead
        upload_flag = request.form.get("uploadFlag")
        sketch_src_2 = None
        if upload_flag:
            # .get avoids a KeyError when the file field is missing
            sketch_src_2 = request.files.get("uploadSketch")
        if sketch_src:
            flag = 1
        elif sketch_src_2:
            flag = 2
        else:
            # Nothing was uploaded: go back to the upload page
            return render_template('upload.html')

        # Save the uploaded sketch
        basepath = os.path.dirname(__file__)
        upload_path = os.path.join(basepath, 'static/sketch_tmp', 'upload.png')
        if flag == 1:
            # Mouse drawing: strip the 22-character "data:image/png;base64," header, then decode
            sketch = base64.b64decode(sketch_src[22:])
            with open(upload_path, "wb") as f:
                f.write(sketch)
        else:
            # Local upload
            sketch_src_2.save(upload_path)
        user_input = request.form.get("name")

        # Run the retrieval
        retrieval_list, real_path = retrieval(retrieval_net, upload_path, flickr15k_dataset)
        # Wrap the returned paths as JSON
        real_path = json.dumps(real_path)
        # Render the results page
        return render_template('retrieval.html', userinput=user_input, val1=time.time(),
                               upload_src=sketch_src, retrieval_list=retrieval_list,
                               json_info=real_path)
    # Any other request: return to the upload page
    return render_template('upload.html')


if __name__ == '__main__':
    app.run(debug=True)
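First, controller.py loads the dataset and the model into memory once, at start-up; this is what flickr15k_dataset_lite() and load_model_retrieval() do. Note that load_model_retrieval() moves the network to the GPU with .cuda(), so the code as written needs a CUDA-capable GPU.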
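Because `load_model_retrieval()` calls `.cuda()` unconditionally, it fails on a machine without a GPU. Below is a hedged sketch of the same load pattern (`load_state_dict`, move to device, `eval()`) that picks the device at run time and passes `map_location` to `torch.load`. `TinyBranch` is a hypothetical stand-in for `BranchNet`/`SketchTriplet`, and the checkpoint is written to an in-memory buffer instead of 500.pth; neither is the author's code.

```python
import io

import torch
import torch.nn as nn


class TinyBranch(nn.Module):
    """Hypothetical stand-in for BranchNet: maps an image tensor to a feature vector."""

    def __init__(self, feat_dim=8):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(4, feat_dim)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv(x))).flatten(1)
        return self.fc(x)


def load_model(checkpoint, device):
    net = TinyBranch()
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    state = torch.load(checkpoint, map_location=device)
    net.load_state_dict(state)
    net.to(device)
    net.eval()  # inference mode for layers such as dropout and batch norm
    return net


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Save a randomly initialised model to an in-memory buffer (instead of 500.pth).
torch.manual_seed(0)
reference = TinyBranch()
buffer = io.BytesIO()
torch.save(reference.state_dict(), buffer)
buffer.seek(0)

net = load_model(buffer, device)

# The loaded weights should be identical to the saved ones.
same = all(
    torch.equal(a.cpu(), b)
    for a, b in zip(net.state_dict().values(), reference.state_dict().values())
)
with torch.no_grad():
    feat = net(torch.randn(2, 3, 32, 32, device=device))
print(same, tuple(feat.shape))  # True (2, 8)
```

With the real model you would replace `TinyBranch()` with `SketchTriplet_hs(BranchNet())` and the buffer with `'../SketchTriplet/500.pth'`.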
Next, I want the canva.html drawing interface to support two ways of submitting a sketch: drawing it directly with the mouse, or uploading an image from the local disk. That is why the code handles two upload behaviors.
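The two upload modes can be exercised in isolation with a stripped-down route. This is a minimal sketch, not the author's controller: it keeps the form field names `sketchUpload` and `uploadSketch` from the code above, but assumes a temporary directory instead of static/sketch_tmp, skips the `uploadFlag` check, and returns plain text instead of calling `retrieval` and rendering retrieval.html. The PNG bytes are placeholders, not a real image.

```python
import base64
import io
import os
import tempfile

from flask import Flask, request

app = Flask(__name__)
# Assumption: a temporary directory stands in for static/sketch_tmp.
UPLOAD_PATH = os.path.join(tempfile.mkdtemp(), 'upload.png')


@app.route('/upload', methods=['POST'])
def upload():
    sketch_src = request.form.get('sketchUpload')    # mode 1: Base64 data URL from the canvas
    sketch_file = request.files.get('uploadSketch')  # mode 2: file chosen from disk
    if sketch_src:
        # Drop the "data:image/png;base64," header before decoding.
        with open(UPLOAD_PATH, 'wb') as f:
            f.write(base64.b64decode(sketch_src.split(',', 1)[1]))
        mode = 'drawn'
    elif sketch_file and sketch_file.filename:
        sketch_file.save(UPLOAD_PATH)
        mode = 'file'
    else:
        return 'no sketch received', 400
    # The real controller would call retrieval(...) here and render retrieval.html.
    return f'{mode}:{os.path.getsize(UPLOAD_PATH)}'


client = app.test_client()
fake_png = b'\x89PNG\r\n\x1a\n' + b'\x00' * 8  # 16 placeholder bytes

data_url = 'data:image/png;base64,' + base64.b64encode(fake_png).decode('ascii')
r1 = client.post('/upload', data={'sketchUpload': data_url})
r2 = client.post('/upload', data={'uploadSketch': (io.BytesIO(fake_png), 'sketch.png')},
                 content_type='multipart/form-data')
r3 = client.post('/upload', data={})
print(r1.get_data(as_text=True), r2.get_data(as_text=True), r3.status_code)
# drawn:16 file:16 400
```

Returning a 400 status for an empty request (rather than silently re-rendering the upload page) is a design choice of this sketch; it makes failed uploads easy to spot in tests.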
Then the uploaded image has to be saved. For mouse drawing, the drawing interface encodes the result as Base64, so I decode it with base64.b64decode
and save it to static/sketch_tmp/upload.png
. For local upload, I took a shortcut and simply save a copy of the chosen local file, which is not really the right approach; if you read the code closely, you will see that this step involves front-end/back-end interaction.
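The slice `sketch_src[22:]` in the controller works because the header of a PNG data URL, `data:image/png;base64,`, is exactly 22 characters long. A canvas exported as JPEG (`toDataURL('image/jpeg')`) has a 23-character header, so the fixed slice would corrupt it. The helper below, `decode_data_url`, is a hypothetical, standard-library-only sketch that splits on the comma and checks the header instead of relying on a fixed offset.

```python
import base64


def decode_data_url(data_url):
    """Split a 'data:<mime>;base64,<payload>' string into (mime, raw bytes)."""
    header, sep, payload = data_url.partition(',')
    if not sep or not header.startswith('data:') or not header.endswith(';base64'):
        raise ValueError('not a base64 data URL')
    mime = header[len('data:'):-len(';base64')]
    return mime, base64.b64decode(payload)


prefix = 'data:image/png;base64,'
print(len(prefix))  # 22, which is why the controller slices sketch_src[22:]

url = prefix + base64.b64encode(b'sketch bytes').decode('ascii')
mime, raw = decode_data_url(url)
print(mime, raw)                          # image/png b'sketch bytes'
print(raw == base64.b64decode(url[22:]))  # True
```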
The back end then runs the retrieval function on the saved image and gets back the paths of the retrieved photos
. I serialize these paths with json.dumps.
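The helper `get_real_path` is not shown, so we cannot tell what type the returned paths have. If they were still a NumPy array (as `path_p[order]` is inside `retrieval`), `json.dumps` would raise a `TypeError`. This sketch uses hypothetical paths to show the problem and the usual fix, converting with `tolist()` first.

```python
import json

import numpy as np

# Hypothetical ranked paths, as a NumPy string array like path_p[order] in retrieval().
ranked = np.array(['images/1.jpg', 'images/7.jpg', 'images/3.jpg'])

try:
    json.dumps(ranked)
except TypeError as err:
    print('ndarray is not JSON serializable:', err)

# Converting to a plain list of Python str makes it serializable.
json_info = json.dumps(ranked.tolist())
print(json_info)  # ["images/1.jpg", "images/7.jpg", "images/3.jpg"]
print(json.loads(json_info) == ranked.tolist())  # True
```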
Finally, the JSON-wrapped retrieval results and the uploaded sketch are passed to the retrieval.html
results page for rendering, which displays both the input sketch and the retrieved images. That is the whole back-end workflow.
2.3 Backend search
The retrieval
function called in controller.py
is implemented as follows:
import numpy as np
from PIL import Image


def retrieval(net, sketch_path, dataset):
    # Open the input sketch with PIL
    sketch_src = Image.open(sketch_path).convert('RGB')
    # Extract the feature of the hand-drawn sketch
    # (extract_feat_sketch is defined elsewhere in retrieval.py and not shown here)
    feat_s = extract_feat_sketch(net, sketch_src)
    # Read the precomputed feature set of the flickr15k natural images
    feat_photo_path = '../SketchTriplet/feat.npz'
    feat_photo = np.load(feat_photo_path)
    # Parse the file contents
    feat_p = feat_photo['feat']
    cls_name_p = feat_photo['cls_name']
    cls_num_p = feat_photo['cls_num']
    path_p = feat_photo['path']
    name_p = feat_photo['name']
    # L2 (Euclidean) distance between the sketch feature and every photo feature
    dist_l2 = np.sqrt(np.sum(np.square(feat_s - feat_p), 1))
    # Sort by ascending distance, nearest first
    order = np.argsort(dist_l2)
    # Relative paths in ranked order
    order_path_p = path_p[order]
    # Convert to absolute paths (get_real_path is also not shown;
    # controller.py unpacks two values from its result)
    return get_real_path(order_path_p)
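The ranking step itself is plain NumPy broadcasting plus `argsort`. Below is a self-contained sketch with made-up 2-D features; the feature values, the paths a.jpg/b.jpg/c.jpg, and the helper name `rank_by_l2` with its `top_k` parameter are all hypothetical. It also writes a small .npz with the same `feat` and `path` keys that `retrieval` reads, and loads it once up front, which is an alternative to re-reading feat.npz on every query as the function above does.

```python
import os
import tempfile

import numpy as np

# Hypothetical photo feature file with the same keys retrieval() reads from feat.npz.
feat_p = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]], dtype=np.float32)
path_p = np.array(['a.jpg', 'b.jpg', 'c.jpg'])
npz_path = os.path.join(tempfile.mkdtemp(), 'feat.npz')
np.savez(npz_path, feat=feat_p, path=path_p)

# Load once (e.g. at start-up) instead of on every query.
with np.load(npz_path) as feat_photo:
    feat_p = feat_photo['feat']
    path_p = feat_photo['path']


def rank_by_l2(feat_s, feat_p, path_p, top_k=None):
    """Return photo paths and distances sorted by Euclidean distance to the sketch feature."""
    # feat_s has shape (D,), feat_p has shape (N, D); broadcasting subtracts row by row.
    dist_l2 = np.sqrt(np.sum(np.square(feat_s - feat_p), axis=1))
    order = np.argsort(dist_l2)
    if top_k is not None:
        order = order[:top_k]
    return path_p[order], dist_l2[order]


feat_s = np.array([1.0, 0.5], dtype=np.float32)  # hypothetical sketch feature
paths, dists = rank_by_l2(feat_s, feat_p, path_p, top_k=2)
print(paths.tolist())  # ['c.jpg', 'a.jpg']
```

Since the square root is monotonic, sorting by the squared distance would give the same order; it is kept here to match the original code.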
3. Summary
This article explained how to build the back end with Flask and briefly walked through the retrieval process, from uploading a hand-drawn sketch to computing the retrieval results. The next articles will cover the front end, the front-end/back-end interaction, and a demo, so stay tuned.