Using java-ort (ONNX Runtime for Java) to deploy YOLO-series models

The process is as follows. Load the weights, read the input image, resize it (with padding, to keep the aspect ratio) to the input shape expected by the model, and divide by 255 to normalize. The pixels must be arranged in CHW order before they are fed to the model. After prediction, parse the 25200*85 output, filter the boxes with the confidence and NMS thresholds, convert the coordinates of the remaining boxes back to the original image coordinate system using the scaling ratio, and finally draw them on the image.
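The resize, padding and CHW steps above can be sketched in plain Java, independent of OpenCV and ONNX Runtime. This is a minimal sketch under assumptions: the class and method names (LetterboxSketch, compute, toOriginal, hwcToChw) are hypothetical, and the rounding is meant to mirror YOLOv5's letterbox rather than to be a guaranteed reproduction of it. It shows the forward transform (scale, then pad), the inverse mapping of a coordinate back to the original image, and the HWC to CHW reorder.

```java
import java.util.Arrays;

public class LetterboxSketch {
    // Hypothetical holder for the scale factor and border sizes of a letterbox resize.
    static final class Letterbox {
        final double gain;
        final int newW, newH, top, bottom, left, right;

        Letterbox(double gain, int newW, int newH, int top, int bottom, int left, int right) {
            this.gain = gain;
            this.newW = newW;
            this.newH = newH;
            this.top = top;
            this.bottom = bottom;
            this.left = left;
            this.right = right;
        }
    }

    // Fit a srcW x srcH image into netW x netH without distorting its aspect ratio.
    static Letterbox compute(int srcW, int srcH, int netW, int netH) {
        double gain = Math.min((double) netW / srcW, (double) netH / srcH);
        int newW = (int) Math.round(srcW * gain);
        int newH = (int) Math.round(srcH * gain);
        // Fractional padding: an odd remainder puts the extra pixel on the bottom/right side.
        double dw = (netW - newW) / 2.0;
        double dh = (netH - newH) / 2.0;
        return new Letterbox(gain, newW, newH,
                (int) Math.round(dh - 0.1), (int) Math.round(dh + 0.1),
                (int) Math.round(dw - 0.1), (int) Math.round(dw + 0.1));
    }

    // Map a coordinate on the padded network input back to the original image.
    static double toOriginal(double v, double pad, double gain) {
        return (v - pad) / gain;
    }

    // Reorder interleaved HWC pixels (OpenCV layout) into planar CHW (model layout).
    static float[] hwcToChw(float[] hwc, int h, int w, int c) {
        float[] chw = new float[hwc.length];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                for (int ch = 0; ch < c; ch++) {
                    chw[ch * h * w + y * w + x] = hwc[(y * w + x) * c + ch];
                }
            }
        }
        return chw;
    }

    public static void main(String[] args) {
        Letterbox lb = compute(1920, 1080, 640, 640);
        System.out.println(lb.newW + "x" + lb.newH + ", top=" + lb.top + ", bottom=" + lb.bottom);
        double padH = (640 - 1080 * lb.gain) / 2.0;
        System.out.println(toOriginal(320, padH, lb.gain)); // close to 540
        float[] hwc = {1, 2, 3, 4, 5, 6}; // two pixels: (1,2,3) and (4,5,6)
        System.out.println(Arrays.toString(hwcToChw(hwc, 1, 2, 3))); // [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
    }
}
```

For a 1920*1080 image the scaled size is 640*360 with 140 rows of padding above and below. In the full program further down, resizeWithPadding and transferSrc2Dst play the roles of compute and toOriginal, and whc2cwh does the reorder for the fixed 3-channel case.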

The model outputs a 1*25200*85 tensor, which means 25200 candidate boxes are predicted for the image. Each box is described by 85 floats: the center point coordinates, width and height (indices 0 to 3), the box confidence (index 4), and the probability of each of the 80 categories (indices 5 to 84). You parse this output yourself. The model is loaded and run with Microsoft's com.microsoft.onnxruntime library.
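To make the 1*25200*85 layout concrete, the following pure-Java sketch (a) derives 25200 from the three YOLOv5 detection heads, assuming the default of 3 anchors per grid cell and strides of 8, 16 and 32 on a 640*640 input, and (b) decodes one 85-float row into corner coordinates, box confidence and the best class. YoloRowDecodeSketch, candidateCount and decodeRow are hypothetical names used only for illustration.

```java
import java.util.Arrays;

public class YoloRowDecodeSketch {
    // Candidate box count: 3 anchors per cell on the stride-8, 16 and 32 feature maps.
    static int candidateCount(int inputSize) {
        int total = 0;
        for (int stride : new int[]{8, 16, 32}) {
            int g = inputSize / stride; // grid size of this feature map
            total += 3 * g * g;
        }
        return total;
    }

    // Decode one row: {x1, y1, x2, y2, boxConfidence, bestClassIndex, bestClassScore}.
    static float[] decodeRow(float[] row) {
        float cx = row[0], cy = row[1], w = row[2], h = row[3];
        float confidence = row[4];
        int best = 5; // class scores start at index 5
        for (int i = 6; i < row.length; i++) {
            if (row[i] > row[best]) best = i;
        }
        return new float[]{cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2,
                confidence, best - 5, row[best]};
    }

    public static void main(String[] args) {
        System.out.println(candidateCount(640)); // 25200
        float[] row = new float[85];
        row[0] = 320; row[1] = 240; row[2] = 100; row[3] = 50; row[4] = 0.9f;
        row[5 + 2] = 0.8f; // class index 2 ("car" in the COCO names)
        System.out.println(Arrays.toString(decodeRow(row)));
    }
}
```

Note that 80*80 + 40*40 + 20*20 alone is only 8400; the factor of 3 anchors per cell is what brings it to 25200.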

ONNX (Open Neural Network Exchange) is an open format for representing deep neural network models. It was launched by Microsoft and Facebook in 2017 and quickly gained support from major vendors and frameworks. Within a few years it became the de facto standard for representing deep learning models, and through ONNX-ML it also supports traditional (non-neural-network) machine learning models, largely unifying how AI models are exchanged. ONNX defines a standard format that is independent of environment and platform, which makes AI models interoperable across frameworks and environments. Hardware and software vendors can optimize model performance against the ONNX standard, and every ONNX-compatible framework benefits. Simply put, ONNX is the intermediary for model conversion.

Use YOLOv5's export.py script to convert the model to ONNX:

python export.py --weights C:\Users\tangyufan\Desktop\custom\res\custom-01\weights\best.pt --include torchscript onnx --opset 16

Note:

Here --include torchscript onnx generates the ONNX file (alongside a TorchScript file), and --opset 16 fixes the operator set version, because the com.microsoft.onnxruntime Java library used to load the model requires an opset no higher than 16.

Each ONNX version supports a different set of operators, so the value of opset_version determines which PyTorch operators can be exported to ONNX. If the model uses an operator that the chosen operator set does not support, the export fails. In addition, the exported model can only run on runtimes that support the corresponding opset version.

The complete code is as follows:

package tool.yolo.onnxruntime;

import ai.onnxruntime.*;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.opencv.core.*;
import org.opencv.core.Point;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import javax.imageio.ImageIO;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.nio.FloatBuffer;
import java.text.DecimalFormat;
import java.util.*;

/**
* @desc : Use com.microsoft.onnxruntime to load yolov5 onnx for inference
* @auth: tyf
* @date : 2023-03-21 09:31:31
*/
public class predictTest {
    // onnxruntime environment
    public static OrtEnvironment env;
    public static OrtSession session;
    // The category information of the model, read from the weight
    public static JSONObject names;
    // The input shape of the model, read from the weight
    public static long count;//1 model processes one image at a time
    public static long channels;//3 model channel number
    public static long netHeight;//640 model height
    public static long netWidth;//640 model width
    // Detection box screening threshold, refer to the settings in detect.py
    public static float confThreshold = 0.25f;
    public static float nmsThreshold = 0.45f;
    // onnxruntime environment initialization
    static {
        // The com.microsoft.onnxruntime library used here only supports opset <= 16, so export the model with --opset 16
        String weight = "C:\\Users\\tyf\\Desktop\\yolov5s.onnx";
        try {
            env = OrtEnvironment.getEnvironment();
            session = env.createSession(weight, new OrtSession.SessionOptions());
            // Print the model information: getCustomMetadata contains the category names, and the input info contains the input width and height
            OnnxModelMetadata metadata = session.getMetadata();
            Map<String, NodeInfo> infoMap = session.getInputInfo();
            TensorInfo nodeInfo = (TensorInfo) infoMap.get("images").getInfo();
            String nameClass = metadata.getCustomMetadata().get("names");
            System.out.println("-------print model information start-------");
            System.out.println("getProducerName=" + metadata.getProducerName());
            System.out.println("getGraphName=" + metadata.getGraphName());
            System.out.println("getDescription=" + metadata.getDescription());
            System.out.println("getDomain=" + metadata.getDomain());
            System.out.println("getVersion=" + metadata.getVersion());
            System.out.println("getCustomMetadata=" + metadata.getCustomMetadata());
            System.out.println("getInputInfo=" + infoMap);
            System.out.println("nodeInfo=" + nodeInfo);
            System.out.println("-------End of printing model information-------");
            // Read the category names, e.g. {0: 'person', 1: 'bicycle', 2: 'car'}
            // Single quotes are replaced with double quotes so the metadata string parses as JSON
            names = JSONObject.parseObject(nameClass.replace("'", "\""));
            System.out.println("Category information:" + names);
            // Read the shape of the input tensor (NCHW). Every image must be resized to this size before inference,
            // and the box coordinates output by the model must be mapped back. For YOLOv5 it is 640*640
            count = nodeInfo.getShape()[0];//1: the model processes one image at a time
            channels = nodeInfo.getShape()[1];//3: number of channels
            netHeight = nodeInfo.getShape()[2];//640: model input height
            netWidth = nodeInfo.getShape()[3];//640: model input width
            System.out.println("Model channel number=" + channels + ", network input height = " + netHeight + ", network input width = " + netWidth);
            // Load the OpenCV native library: copy opencv\build\java\x64\opencv_java455.dll to the bin directory of the JDK installation
            // (the library can also be obtained from the org.openpnp opencv dependency)
            System.loadLibrary(Core.NATIVE_LIBRARY_NAME);

        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(1);
        }
    }
    // Use opencv to read pictures to mat
    public static Mat readImg(String path){
        Mat img = Imgcodecs.imread(path);
        return img;
    }

    // Letterbox: scale src to fit the network input while keeping its aspect ratio, then pad the borders
    public static Mat resizeWithPadding(Mat src) {
        Mat dst = new Mat();
        int oldW = src.width();
        int oldH = src.height();
        double r = Math.min((double) netWidth / oldW, (double) netHeight / oldH);
        int newUnpadW = (int) Math.round(oldW * r);
        int newUnpadH = (int) Math.round(oldH * r);
        // Floating-point padding, so an odd remainder is split between the two sides and the result is exactly netWidth x netHeight
        double dw = (netWidth - newUnpadW) / 2.0;
        double dh = (netHeight - newUnpadH) / 2.0;
        int top = (int) Math.round(dh - 0.1);
        int bottom = (int) Math.round(dh + 0.1);
        int left = (int) Math.round(dw - 0.1);
        int right = (int) Math.round(dw + 0.1);
        Imgproc.resize(src, dst, new Size(newUnpadW, newUnpadH));
        // Pad with gray (114, 114, 114), the padding color used by YOLOv5's letterbox
        Core.copyMakeBorder(dst, dst, top, bottom, left, right, Core.BORDER_CONSTANT, new Scalar(114, 114, 114));
        return dst;
    }
    // Convert the image matrix to the tensor required by onnxruntime
    // YOLO's input preprocessing requires normalization, BGR -> RGB, etc. For details, see the detect.py script
    public static OnnxTensor transferTensor(Mat dst){
        // BGR -> RGB
        Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);
        // Normalize 0-255 to 0-1 (the Mat keeps its 3 channels, now as 32-bit floats)
        dst.convertTo(dst, CvType.CV_32FC3, 1. / 255);
        // Array of size channels * netHeight * netWidth, filled in OpenCV's interleaved HWC order
        float[] hwc = new float[(int) (channels * netHeight * netWidth)];
        dst.get(0, 0, hwc);
        // Reorder HWC -> CHW as the model expects
        float[] chw = whc2cwh(hwc);
        // Create the tensor needed by onnxruntime from the float array; the shape is NCHW = {count, channels, netHeight, netWidth}
        OnnxTensor tensor = null;
        try {
            tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), new long[]{count, channels, netHeight, netWidth});
        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(1);
        }
        return tensor;
    }

    // Reorder an interleaved HWC array (R,G,B,R,G,B,...) into planar CHW (all R, then all G, then all B)
    public static float[] whc2cwh(float[] src) {
        float[] chw = new float[src.length];
        int j = 0;
        for (int ch = 0; ch < 3; ++ch) {
            for (int i = ch; i < src.length; i += 3) {
                chw[j] = src[i];
                j++;
            }
        }
        return chw;
    }

    // Get the subscript of the maximum value in the array, and find the category with the highest probability among the 80 categories
    public static int getMaxIndex(float[] array) {
        int maxIndex = 0;
        float maxVal = array[0];
        for (int i = 1; i < array.length; i++) {
            if (array[i] > maxVal) {
                maxVal = array[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }


    // Convert center point coordinates (x, y, w, h) to xmin ymin xmax ymax
    public static float[] xywh2xyxy(float[] bbox) {
        // center point coordinates
        float x = bbox[0];
        float y = bbox[1];
        float w = bbox[2];
        float h = bbox[3];
        // calculate
        float x1 = x - w * 0.5f;
        float y1 = y - h * 0.5f;
        float x2 = x + w * 0.5f;
        float y2 = y + h * 0.5f;
        // limited to the image area
        return new float[]{
                x1 < 0 ? 0 : x1,
                y1 < 0 ? 0 : y1,
                x2 > netWidth ? netWidth: x2,
                y2 > netHeight? netHeight:y2};
    }

    // Process the 25200*85 model output: keep the boxes whose confidence reaches the threshold
    public static JSONArray filterRec1(float[][] data){
        JSONArray recList = new JSONArray();
        // Traverse the 25200 detection boxes
        // 25200 = 3 * (80 * 80 + 40 * 40 + 20 * 20): three anchors per grid cell on the
        // feature maps obtained with strides of 8, 16 and 32 pixels
        for (float[] bbox : data){
            // Each detection box is an 85-element array holding the center point, confidence, category probabilities, etc.:
            // Indices 0~3 are x y w h (center point coordinates, width and height). They are converted to xyxy, i.e. the top-left
            // and bottom-right corners, and clipped to the network input width and height
            float[] xywh = new float[] {bbox[0],bbox[1],bbox[2],bbox[3]};
            float[] xyxy = xywh2xyxy(xywh);
            // Index 4 is the confidence score of the detection box
            float confidence = bbox[4];
            // Indices 5~84 are the probability scores of all 80 categories; find the maximum value and its index
            float[] classInfo = Arrays.copyOfRange(bbox, 5, 85);
            int maxIndex = getMaxIndex(classInfo);// index of the category with the highest probability
            float maxValue = classInfo[maxIndex];// probability of that category
            String maxClass = (String)names.get(Integer.valueOf(maxIndex));// label of that category
            // First, a rough selection based on the box confidence
            if(confidence>=confThreshold){
                // Boxes below the confidence threshold are discarded. The remaining boxes still contain duplicates (one target is
                // often detected by several boxes); these are removed in the next step by NMS based on IoU. Algorithm reference: detect.py
                JSONObject detect = new JSONObject();
                detect.put("name",maxClass);// category
                detect.put("percentage",maxValue);// probability
                detect.put("xmin",xyxy[0]);
                detect.put("ymin",xyxy[1]);
                detect.put("xmax",xyxy[2]);
                detect.put("ymax", xyxy[3]);
                recList.add(detect);
            }
        }
        return recList;
    }

    // Then perform NMS on the remaining boxes (class-agnostic: boxes of different categories can also suppress each other)
    public static JSONArray filterRec2(JSONArray data){
        // save the result
        JSONArray res = new JSONArray();
        // Sort by probability from high to low (numerically, not as strings)
        data.sort(Comparator.comparingDouble((Object obj) -> ((JSONObject) obj).getDoubleValue("percentage")).reversed());
        // Perform NMS
        while (!data.isEmpty()){
            JSONObject max = data.getJSONObject(0);// Take the detection box with the highest probability each time and save it to the result
            res.add(max);
            Iterator<Object> it = data.iterator();
            // Compute the IoU of this box with every remaining box (including itself, so it is removed too).
            // Boxes whose overlap exceeds the threshold are removed from the original set
            while (it.hasNext()) {
                JSONObject obj = (JSONObject) it.next();
                double iou = calculateIoU(max, obj);
                if (iou > nmsThreshold) {
                    it.remove();
                }
            }
        }
        return res;
    }

    // Calculate the intersection and union ratio of two boxes
    private static double calculateIoU(JSONObject box1, JSONObject box2) {
        double x1 = Math.max(box1.getDouble("xmin"), box2.getDouble("xmin"));
        double y1 = Math.max(box1.getDouble("ymin"), box2.getDouble("ymin"));
        double x2 = Math.min(box1.getDouble("xmax"), box2.getDouble("xmax"));
        double y2 = Math.min(box1.getDouble("ymax"), box2.getDouble("ymax"));
        double intersectionArea = Math.max(0, x2 - x1 + 1) * Math.max(0, y2 - y1 + 1);
        double box1Area = (box1.getDouble("xmax") - box1.getDouble("xmin") + 1) * (box1.getDouble("ymax") - box1.getDouble("ymin") + 1);
        double box2Area = (box2.getDouble("xmax") - box2.getDouble("xmin") + 1) * (box2.getDouble("ymax") - box2.getDouble("ymin") + 1);
        double unionArea = box1Area + box2Area - intersectionArea;
        return intersectionArea / unionArea;
    }

    // Convert the two corner points output by the network to the coordinates of the original image,
    // using the zoom ratio determined by the original width and height and the network input width and height
    // xmin, ymin, xmax, ymax -> (xmin_org, ymin_org, xmax_org, ymax_org)
    public static JSONArray transferSrc2Dst(JSONArray data,int srcw,int srch){
        JSONArray res = new JSONArray();
        System.out.println("-------coordinate conversion--------");
        /*
        srcw and srch are the width and height of the original image.
        gain is the scaling ratio: the smaller of netWidth / srcw and netHeight / srch, so the whole image fits into the network input without changing its aspect ratio.
        padW and padH are the padding added on each side horizontally and vertically: the network input size minus the scaled image size (the original size multiplied by gain), divided by 2.
        Subtracting the padding and dividing by gain undoes the letterbox transform, mapping a point on the padded network input back to the original image.
        */
        float gain = Math.min((float) netWidth / srcw, (float) netHeight / srch);
        float padW = (netWidth - srcw * gain) * 0.5f;
        float padH = (netHeight - srch * gain) * 0.5f;
        data.stream().forEach(n->{
            JSONObject obj = JSONObject.parseObject(n.toString());
            float xmin = obj.getFloat("xmin");
            float ymin = obj.getFloat("ymin");
            float xmax = obj.getFloat("xmax");
            float ymax = obj.getFloat("ymax");
            // coordinates in the original image, clipped to its bounds
            float xmin_ = Math.max(0, Math.min(srcw - 1, (xmin - padW) / gain));
            float ymin_ = Math.max(0, Math.min(srch - 1, (ymin - padH) / gain));
            float xmax_ = Math.max(0, Math.min(srcw - 1, (xmax - padW) / gain));
            float ymax_ = Math.max(0, Math.min(srch - 1, (ymax - padH) / gain));
            obj.put("xmin",xmin_);
            obj.put("ymin",ymin_);
            obj.put("xmax",xmax_);
            obj.put("ymax",ymax_);
            System.out.println("net output coordinates: (" + xmin + "," + ymin + "," + xmax + "," + ymax + "), converted coordinates: (" + xmin_ + "," + ymin_ + "," + xmax_ + "," + ymax_ + ")");
            res.add(obj);
        });
        return res;
    }
    // Draw the boxes on the original picture and display it in a pop-up window
    public static void pointBox(String pic,JSONArray box){
        if(box.size()==0){
            System.out.println("There is no recognition target");
            return;
        }
        try {
            // picture
            File imageFile = new File(pic);
            BufferedImage img = ImageIO.read(imageFile);
            Graphics2D graph = img.createGraphics();
            graph.setStroke(new BasicStroke(2));// line thickness
            graph.setFont(new Font("Serif", Font.BOLD, 20));// text font
            graph.setColor(Color.RED);
            // box information
            box.stream().forEach(n->{
                JSONObject obj = JSONObject.parseObject(n.toString());
                String name = obj.getString("name");
                float percentage = obj.getFloat("percentage");
                float xmin = obj.getFloat("xmin");
                float ymin = obj.getFloat("ymin");
                float xmax = obj.getFloat("xmax");
                float ymax = obj.getFloat("ymax");
                float w = xmax - xmin;
                float h = ymax - ymin;
                // draw the rectangle; in the image coordinate system the top-left corner is (xmin, ymin)
                graph.drawRect(
                        (int) xmin,
                        (int) ymin,
                        (int) w,
                        (int) h);
                // draw the category and the probability, formatted to at most two decimal places
                DecimalFormat decimalFormat = new DecimalFormat("#.##");
                String percentString = decimalFormat.format(percentage);
                graph.drawString(name + " " + percentString, xmin-1, ymin-5);
            });
            // release the graphics context
            graph.dispose();
            // pop-up display
            JFrame frame = new JFrame("Image Dialog");
            frame.setSize(img.getWidth(), img.getHeight());
            JLabel label = new JLabel(new ImageIcon(img));
            frame.getContentPane().add(label);
            frame.setVisible(true);
            frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(1);
        }
    }
    public static void main(String[] args) throws Exception{
        // Read the picture and save the original width and height
        String pic = "C:\\Users\\tyf\\Desktop\\img.png";
        Mat src = readImg(pic);
        int srcw = src.width();
        int srch = src.height();
        // Resize (with padding) to the network input width and height
        Mat dst = resizeWithPadding(src);
        // Preprocess the input image and convert it to a tensor. YOLO's input preprocessing requires normalization,
        // BGR -> RGB and other steps; see the detect.py script for details
        OnnxTensor tensor = transferTensor(dst);
        // Run inference
        OrtSession.Result result = session.run(Collections.singletonMap("images", tensor));
        // Obtain the model output, a 1*25200*85 matrix
        OnnxTensor res = (OnnxTensor) result.get(0);
        float[][][] dataRes = (float[][][]) res.getValue();
        // Take the 25200*85 matrix
        // This is the final output of the YOLOv5 model: 25200 detection boxes, each described by an array of 85 floats
        float[][] data = dataRes[0];
        // Rough filtering by box confidence
        JSONArray srcRec = filterRec1(data);// after filtering, each JSON object holds the target's category, probability, and top-left and bottom-right coordinates
        // Remove duplicate boxes with NMS
        JSONArray srcRec2 = filterRec2(srcRec);// same fields as above
        // Convert the two corner points output by the network to the coordinates of the original image
        // xmin, ymin, xmax, ymax -> (xmin_org, ymin_org, xmax_org, ymax_org)
        JSONArray dstRec = transferSrc2Dst(srcRec2,srcw,srch);
        // Draw the boxes and category information on the original picture and show it in a pop-up window
        pointBox(pic,dstRec);
    }
}
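filterRec2 above suppresses overlapping boxes regardless of their category. YOLOv5's detect.py, by default, only lets boxes of the same class suppress each other (class-agnostic NMS is opt-in there via --agnostic-nms). If you want that behavior in Java, a minimal sketch might look like the following. The Det type, iou and nms are hypothetical names, and this IoU drops the +1 pixel convention used in calculateIoU.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class ClassAwareNmsSketch {
    // Hypothetical detection type used only in this sketch.
    static final class Det {
        final float x1, y1, x2, y2, score;
        final int cls;

        Det(float x1, float y1, float x2, float y2, float score, int cls) {
            this.x1 = x1; this.y1 = y1; this.x2 = x2; this.y2 = y2;
            this.score = score;
            this.cls = cls;
        }
    }

    // Intersection over union of two corner-format boxes.
    static float iou(Det a, Det b) {
        float iw = Math.max(0f, Math.min(a.x2, b.x2) - Math.max(a.x1, b.x1));
        float ih = Math.max(0f, Math.min(a.y2, b.y2) - Math.max(a.y1, b.y1));
        float inter = iw * ih;
        float union = (a.x2 - a.x1) * (a.y2 - a.y1) + (b.x2 - b.x1) * (b.y2 - b.y1) - inter;
        return union <= 0 ? 0f : inter / union;
    }

    // Greedy NMS in which a kept box only suppresses lower-scored boxes of the same class.
    static List<Det> nms(List<Det> dets, float iouThreshold) {
        List<Det> sorted = new ArrayList<>(dets);
        sorted.sort(Comparator.comparingDouble((Det d) -> d.score).reversed());
        boolean[] removed = new boolean[sorted.size()];
        List<Det> keep = new ArrayList<>();
        for (int i = 0; i < sorted.size(); i++) {
            if (removed[i]) continue;
            Det best = sorted.get(i);
            keep.add(best);
            for (int j = i + 1; j < sorted.size(); j++) {
                Det other = sorted.get(j);
                if (!removed[j] && other.cls == best.cls && iou(best, other) > iouThreshold) {
                    removed[j] = true;
                }
            }
        }
        return keep;
    }

    public static void main(String[] args) {
        List<Det> dets = List.of(
                new Det(0, 0, 100, 100, 0.9f, 0),
                new Det(5, 5, 105, 105, 0.8f, 0),  // same class, IoU about 0.82: suppressed
                new Det(5, 5, 105, 105, 0.7f, 1)); // different class: kept
        System.out.println(nms(dets, 0.45f).size()); // 2
    }
}
```

To use this with the code above, filterRec1 would also need to store the class index (or you could compare the "name" field instead of cls).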

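There are actually two Maven dependencies: the former can only run inference on the CPU, while the latter can run on either the CPU or the GPU.

<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.11.0</version>
</dependency>

<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime_gpu</artifactId>
    <version>1.11.0</version>
</dependency>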

Set up the GPU as follows (this requires the onnxruntime_gpu dependency and a working CUDA installation):

int gpuDeviceId = 0; // The GPU device ID to execute on
var sessionOptions = new OrtSession.SessionOptions();
sessionOptions.addCUDA(gpuDeviceId);
var session = environment.createSession("model.onnx", sessionOptions); // environment is an OrtEnvironment

Here gpuDeviceId is the CUDA device index, which can be queried with a CUDA script; on this machine it is 0.
