Deploy YoloV3-PyTorch using Flask

1. Project Introduction

This project is a small demo of web-based object detection: it uses YOLOv3 (PyTorch) and Flask to perform object detection on the web, and involves object detection, Flask, and HTML.
The YOLOv3 implementation comes from Ultralytics; you can use their project to train a model that suits your own needs.

2. Overall project framework and code

Project address: https://github.com/BonesCat/Yolov3_flask

The main modifications to the Ultralytics YOLOv3 code are as follows:

  • Modify the original detect.py into detect_for_flask.py to provide an interface for Flask.
  • All uploaded files are renamed with a timestamp and saved to the “upload_files” folder.
  • Detected images are saved to the “output” folder (see the directory sketch after this list).
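
After a few requests, the working directory will look roughly like this (a sketch based on the configuration in the code below; the timestamped names depend on the upload time):

upload_files/2020-05-06-13_05_42.jpg    # timestamped copy of an upload
output/2020-05-06-13_05_42.jpg          # the same image with detection boxes drawn
static/temp.jpg                         # temporary copy shown on the result page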

3. Quick start

  • Set up the environment according to the requirements of ult-yolov3 and install Flask yourself. Note that you should install and configure everything inside a virtual environment (e.g. venv or conda).
  • Download or train a model, put the “.weights/.pt” file into the weights folder, and configure the correct cfg; other options can be set through opt. This project can use the official weights provided by the original YOLOv3, as long as the corresponding cfg is set.
  • Start serve.py, then open “http://127.0.0.1:2222/upload” in your browser and upload an image; you will get the result image and the detection information (see the test sketch below).
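
If you would rather test the endpoint from code than from the browser, here is a minimal sketch using the third-party requests library (test.jpg is a placeholder file name):

import requests

# POST an image to the upload endpoint; the form field must be named 'file'
with open('test.jpg', 'rb') as fp:
    resp = requests.post('http://127.0.0.1:2222/upload', files={'file': fp})
print(resp.status_code)  # 200 on success; the body is the rendered result page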

4. Core code and simple explanation

  • serve.py
import time
import os
#Import the Flask class and request object in the flask library
from flask import Flask, request, flash, redirect, render_template, jsonify
from datetime import timedelta

#Import model related functions
from detect_for_flask import *


app = Flask(__name__)
# flash() stores its messages in the session, which requires a secret key to be set
app.secret_key = 'replace-with-a-random-secret'

#Set the saving location of uploaded files
UPLOAD_FOLDER = 'upload_files'
ALLOWED_EXTENSIONS = {'pdf', 'png', 'jpg', 'jpeg', 'gif'}

# Configure the path to app
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# Set the static file cache expiration time
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = timedelta(seconds=5)  # timedelta (from datetime) represents the difference between two times

print("SEND_FILE_MAX_AGE_DEFAULT:", app.config['SEND_FILE_MAX_AGE_DEFAULT'])

# Pre-initialize the model
model_inited, opt = init_model()

# Check that the uploaded file name has an allowed extension
def allow_filename(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

@app.route('/upload', methods=['GET', 'POST'])  # Add route
def upload():
    if request.method == 'POST':
        # If the request does not contain a file part
        if 'file' not in request.files:
            # Flask message flash
            flash('No file part!')
            # Redisplay the current URL
            return redirect(request.url)

        '''
        The request object in the Flask framework stores all the information of an HTTP request;
        request.files records the files uploaded with the request
        '''
        f = request.files['file']

        # Handle an empty file name (no file selected)
        if f.filename == '':
            flash("No file uploaded")
            return redirect(request.url)

        # The file is non-empty and its extension is allowed
        if f and allow_filename(f.filename):
            # Save the uploaded file locally,
            # using the current time as its new file name
            now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
            file_extension = f.filename.split('.')[-1]
            new_filename = now + '.' + file_extension
            file_path = './' + app.config['UPLOAD_FOLDER'] + '/' + new_filename
            f.save(file_path)

            # Run detection and render the result page
            img, obj_infos = detect(model_inited, opt, file_path)
            return render_template('upload_ok.html', det_result=obj_infos)
    return render_template('upload.html')

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=2222)

  • detect_for_flask.py

import argparse
from sys import platform

from models import * # set ONNX_EXPORT in models.py
from utils.datasets import *
from utils.utils import *

'''
The detection function from detect.py in the original YOLOv3 is rewritten here to work with Flask
'''


def init_model():
    '''
    Initialize the model and its parameters.
    :return: the initialized model and the opt settings
    '''
    # Parameter configuration
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='*.cfg path')
    parser.add_argument('--names', type=str, default='data/coco.names', help='*.names path')
    parser.add_argument('--weights', type=str, default='weights/yolov3.weights', help='weights path')
    parser.add_argument('--output', type=str, default='output', help='output folder') # detect result will be saved here
    parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.3, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.6, help='IOU threshold for NMS')
    parser.add_argument('--device', default='cpu', help='device id (i.e. 0 or 0,1) or cpu')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    opt = parser.parse_args()
    print(opt)

    # Init parameters
    out, weights, save_txt = opt.output, opt.weights, opt.save_txt

    #Initialize
    device = torch_utils.select_device(device='cpu' if ONNX_EXPORT else opt.device)
    if not os.path.exists(out):
        os.makedirs(out) # make new output folder

    #Initialize model
    model = Darknet(opt.cfg, opt.img_size)

    # Load weights
    attempt_download(weights)
    if weights.endswith('.pt'): # pytorch format
        model.load_state_dict(torch.load(weights, map_location=device)['model'])
    else: # darknet format
        load_darknet_weights(model, weights)

    return model, opt

def detect(model, opt, image_path):
    '''
    :param model: the initialized model
    :param opt: the opt settings
    :param image_path: path of the input image
    :return: the annotated image and a list of detection info strings
    '''
    # Eval mode
    model.to(opt.device).eval()
    # Save img?
    save_img = True

    # Process the uploaded image

    # read img
    img0 = cv2.imread(image_path) # BGR
    assert img0 is not None, 'Image Not Found ' + image_path

    # Padded resize
    img = letterbox(img0, new_shape=opt.img_size)[0]

    #Convert
    img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
    img = np.ascontiguousarray(img)

    # Get names and colors
    names = load_classes(opt.names)
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

    # Run inference
    t0 = time.time()

    img = torch.from_numpy(img).to(opt.device)
    img = img.float() # uint8 to fp16/32
    img /= 255.0 # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)
    with torch.no_grad():
        #Inference
        t1 = torch_utils.time_synchronized()
        pred = model(img)[0]
        t2 = torch_utils.time_synchronized()
        # print("pred:", pred)

        # Apply NMS
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)

        #Process detections
        for i, det in enumerate(pred): # detections per image
            # det holds all detected objects for this image as a 2D tensor:
            # each row stores an object's top-left and bottom-right coordinates, confidence, and class
            # print("det", det)

            p, s = image_path, ''

            save_path = str(Path(opt.output) / Path(p).name)
            s += '%gx%g ' % img.shape[2:]  # print string
            # Collect info for each detected object; initialized here so the
            # return below also works when nothing is detected
            obj_info_list = []
            # If any object was detected, det is not empty
            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += '%g %ss, ' % (n, names[int(c)])  # add to string

                # Traverse each row of det to handle each object
                # Write results
                for *xyxy, conf, cls in det:
                    if opt.save_txt:  # Write to file
                        with open(save_path + '.txt', 'a') as file:
                            file.write(('%g ' * 6 + '\n') % (*xyxy, cls, conf))

                    if save_img: # Add bbox to image
                        label = '%s %.2f' % (names[int(cls)], conf)
                        plot_one_box(xyxy, img0, label=label, color=colors[int(cls)]) # The parameter xyxy contains the coordinates of bbox
                    # Record the coordinates, category, and confidence of a single target
                    sig_obj_info = '%s %g %g %g %g %g' % (names[int(cls)], *xyxy, conf)
                    print("sig_obj_info:", sig_obj_info)
                    obj_info_list.append(sig_obj_info)

            # Print time (inference + NMS)
            print('%sDone. (%.3fs)' % (s, t2 - t1))


            # Save results (image with detections)
            # Save results (image with detections)
            if save_img:
                # Save two copies:
                # 1. a permanent copy of the detection result in the output folder
                cv2.imwrite(save_path, img0)
                # 2. a temporary file used for display on the web page
                cv2.imwrite('./static/temp.jpg', img0)

    print('Done. (%.3fs)' % (time.time() - t0))
    return img0, obj_info_list


if __name__ == '__main__':
    img_path = './data/samples/timg1.jpg'
    model_inited, opt = init_model()
    result, obj_infos = detect(model=model_inited, opt=opt, image_path=img_path)
    print(obj_infos)
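
Each entry of obj_infos is a plain string of the form “name x1 y1 x2 y2 conf”, following the format string in detect(). As a usage note, here is a minimal sketch of how a caller could parse such a string back into fields (parse_obj_info is a hypothetical helper, not part of the repository; it unpacks from the right because COCO names such as “traffic light” contain spaces):

def parse_obj_info(line):
    # 'traffic light 53 21 268 402 0.91' -> ('traffic light', [53.0, 21.0, 268.0, 402.0], 0.91)
    *name_parts, x1, y1, x2, y2, conf = line.split()
    name = ' '.join(name_parts)
    return name, [float(x1), float(y1), float(x2), float(y2)], float(conf)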

5. Project screenshots


6. References and Acknowledgments

https://github.com/ultralytics/yolov3
https://blog.csdn.net/rain2211/article/details/105965313/

Note: this is just a simple demo, and there is no dedicated handling for the case where nothing is detected. You can add that error handling yourself; one possible approach is sketched below.
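
For example, a minimal sketch of how upload() in serve.py could report that case back to the user (this assumes detect() returns an empty obj_info_list when nothing is found, as in the code above):

            # In upload(), after running detection:
            img, obj_infos = detect(model_inited, opt, file_path)
            if not obj_infos:  # nothing was detected
                flash('No objects detected in the uploaded image')
                return redirect(request.url)
            return render_template('upload_ok.html', det_result=obj_infos)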