Based on YOLOv5n/s/m models of different parameter magnitudes, a tea bud detection and recognition model was developed; pruning was then used to make the models lighter, and the impact of different pruning levels on model performance was explored.

Today I finally had some time to revisit a problem I had set aside earlier: pruning a trained object detection model. Here I use the tea bud detection scenario as the example. I introduced the related work in a previous blog post; if you are interested, you can read it here:

“Integrated CBAM Attention Mechanism Based on YOLOv5 Development and Construction of Maojian Tea Shoots Detection and Recognition System”

I won’t go into details here.

In this post, three models of different sizes (n/s/m) are trained in turn with identical parameter settings, and the performance impact of different pruning operations is then explored.

First, a quick look at the dataset:

The training instructions for the three models are as follows:

#yolov5n
python3 train.py --cfg models/yolov5n.yaml --weights weights/yolov5n.pt --name yolov5n --epochs 100 --batch-size 4 --img-size 416

#yolov5s
python3 train.py --cfg models/yolov5s.yaml --weights weights/yolov5s.pt --name yolov5s --epochs 100 --batch-size 4 --img-size 416

#yolov5m
python3 train.py --cfg models/yolov5m.yaml --weights weights/yolov5m.pt --name yolov5m --epochs 100 --batch-size 4 --img-size 416

Two settings deserve mention. First, the batch size is set relatively small (4) because the three models are trained at the same time. Second, to speed up the experiments, the image size is set to 416, a relatively low resolution, instead of the default 640.
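As a rough illustration of why 416 speeds things up, the sketch below compares pixel counts at the two resolutions. It assumes convolution cost scales roughly with the number of input pixels (a simplification), and checks that both sizes are multiples of 32, the largest stride of the default YOLOv5 n/s/m models.

```python
# Rough compute comparison between --img-size 416 and the default 640.
# Assumption: convolutional cost scales roughly with the number of input pixels.
MAX_STRIDE = 32  # largest feature-map stride of the default YOLOv5 P5 models


def pixel_ratio(small: int, large: int) -> float:
    """Fraction of pixels (roughly, of conv FLOPs) at `small` relative to `large`."""
    return (small * small) / (large * large)


for size in (416, 640):
    # YOLOv5 expects image sizes that are multiples of the max stride.
    assert size % MAX_STRIDE == 0, f"{size} is not a multiple of {MAX_STRIDE}"
    grid = size // MAX_STRIDE
    print(f"{size}: {grid}x{grid} grid at stride {MAX_STRIDE}")

print(f"416 uses about {pixel_ratio(416, 640):.0%} of the pixels of 640")
```

With these numbers, a 416×416 input has roughly 42% of the pixels of a 640×640 one, so each epoch should be noticeably cheaper, possibly at some cost in fine detail.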

Each model is trained for 100 epochs. Next, let's look at the training results of each model in turn:
【yolov5n】

【yolov5s】

【yolov5m】

Judging from the final evaluation results, the s model performs slightly worse than, or roughly the same as, the n model, while the m model is better than both.

To compare the three models intuitively, the main metrics are visualized below:

【Precision Curve】
The precision curve is a visualization tool for evaluating how a classifier's precision changes as the decision (confidence) threshold varies. Plotting precision against the threshold shows how the model performs at different operating points; plotting precision against recall over the same thresholds gives the closely related precision-recall (PR) curve.
Precision is the fraction of samples predicted as positive that are actually positive, TP / (TP + FP). Recall is the fraction of actual positive samples that are correctly predicted as positive, TP / (TP + FN).
The steps to draw the precision curve are as follows:
1. Convert predicted probabilities to binary class labels using a range of thresholds: a sample whose predicted probability is greater than the threshold is classified as positive, otherwise as negative.
2. For each threshold, calculate the corresponding precision (and recall).
3. Plot precision against the threshold to form the precision curve.
4. Based on the shape and trend of the curve, select a threshold that meets the required performance.
By examining the precision curve, we can choose a threshold that balances precision and recall for the task at hand. Higher precision means fewer false positives, while higher recall means fewer false negatives. Depending on business needs and cost trade-offs, an appropriate operating point or threshold can be chosen on the curve.
Precision curves are often used together with recall curves to provide a more comprehensive analysis of classifier performance and to help evaluate and compare the performance of different models.
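The steps above can be sketched in a few lines of NumPy. This is a simplified binary-classification sketch: it assumes each prediction is already labelled correct or incorrect, whereas YOLOv5 first matches predicted boxes to ground truth by IoU. The scores and labels are made up for illustration.

```python
import numpy as np


def precision_recall_at(scores, labels, thresholds):
    """Precision and recall when samples with score > t are predicted positive."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    precisions, recalls = [], []
    for t in thresholds:
        pred = scores > t                      # step 1: binarize at threshold t
        tp = np.sum(pred & labels)             # step 2: count outcomes
        fp = np.sum(pred & ~labels)
        fn = np.sum(~pred & labels)
        # Convention (an assumption): precision is 1.0 when nothing is predicted positive.
        precisions.append(tp / (tp + fp) if tp + fp > 0 else 1.0)
        recalls.append(tp / (tp + fn) if tp + fn > 0 else 0.0)
    return np.array(precisions), np.array(recalls)


scores = [0.95, 0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2]   # made-up confidences
labels = [1, 1, 0, 1, 0, 1, 0, 0]                    # 1 = actually positive
thresholds = [0.1, 0.5, 0.85]
p, r = precision_recall_at(scores, labels, thresholds)
for t, pi, ri in zip(thresholds, p, r):
    print(f"threshold {t:.2f}: precision {pi:.2f}, recall {ri:.2f}")  # step 3: points to plot
```

On this toy data, raising the threshold from 0.1 to 0.85 moves precision from 0.50 to 1.00 while recall drops from 1.00 to 0.50, which is exactly the trade-off the curve visualizes.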


【Recall Curve】
The recall curve is a visualization tool for evaluating how a binary classifier's recall changes at different thresholds. Plotting recall against the threshold shows how the model performs at different operating points.
Recall is the fraction of actual positive samples that are correctly predicted as positive. It is also called sensitivity or the true positive rate (TPR).
The steps to draw the recall curve are as follows:
1. Convert predicted probabilities to binary class labels using a range of thresholds: a sample whose predicted probability is greater than the threshold is classified as positive, otherwise as negative.
2. For each threshold, calculate the corresponding recall (and precision).
3. Plot recall against the threshold to form the recall curve.
4. Based on the shape and trend of the curve, select a threshold that meets the desired performance.
By examining the recall curve, we can choose a threshold that balances recall and precision for the task at hand. Higher recall means fewer false negatives, while higher precision means fewer false positives. Depending on business needs and cost trade-offs, an appropriate operating point or threshold can be chosen on the curve.
Recall curves are often used together with Precision Curves to provide a more comprehensive analysis of classifier performance and to help evaluate and compare the performance of different models.
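A vectorized sketch of the recall curve on made-up toy data. Raising the threshold can only remove predictions, so recall never increases as the threshold grows; the test checks exactly that. As before, this is a binary-classification simplification that ignores IoU-based box matching.

```python
import numpy as np


def recall_curve(scores, labels, thresholds):
    """Recall (true positive rate) when samples with score > t are predicted positive."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    thresholds = np.asarray(thresholds, dtype=float)
    n_pos = labels.sum()
    if n_pos == 0:
        raise ValueError("recall is undefined without positive samples")
    # Broadcast to a (num_thresholds, num_samples) matrix of binary predictions.
    predicted_pos = scores[None, :] > thresholds[:, None]
    true_pos = (predicted_pos & labels[None, :]).sum(axis=1)
    return true_pos / n_pos


scores = [0.95, 0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2]   # made-up confidences
labels = [1, 1, 0, 1, 0, 1, 0, 0]                    # 1 = actually positive
thresholds = np.linspace(0.0, 1.0, 11)
print(np.round(recall_curve(scores, labels, thresholds), 2))
```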


【F1 value curve】
The F1-score curve is a visualization tool for evaluating the performance of binary classification models at different thresholds. It helps us understand the overall performance of the model by plotting the relationship between Precision, Recall and F1 scores at different thresholds.
The F1 score is the harmonic mean of precision and recall, F1 = 2PR / (P + R), so it takes both metrics into account. The F1 curve helps locate the threshold that best balances precision and recall.
The steps to draw the F1 curve are as follows:
1. Convert predicted probabilities to binary class labels using a range of thresholds: a sample whose predicted probability is greater than the threshold is classified as positive, otherwise as negative.
2. For each threshold, calculate the corresponding precision, recall and F1 score.
3. Plot precision, recall and F1 against the threshold on the same graph to form the F1 curve.
4. Based on the shape and trend of the F1 curve, select a threshold that meets the required performance.
F1-value curves are often used together with receiver operating characteristic curves (ROC curves) to help evaluate and compare the performance of different models. They provide a more comprehensive classifier performance analysis, and can select appropriate models and threshold settings according to specific application scenarios.
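To make the threshold choice concrete, the sketch below computes precision, recall and F1 over a grid of thresholds and picks the threshold with the highest F1. To my understanding, YOLOv5's F1 curve plot similarly annotates the peak F1 and the confidence at which it occurs; this is a simplified binary version on made-up data, not YOLOv5's code.

```python
import numpy as np


def f1_curve(scores, labels, thresholds, eps=1e-16):
    """Precision, recall and F1 for predictions with score > t, for each threshold t."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    thresholds = np.asarray(thresholds, dtype=float)
    pred = scores[None, :] > thresholds[:, None]        # (num_thresholds, num_samples)
    tp = (pred & labels).sum(axis=1)
    fp = (pred & ~labels).sum(axis=1)
    fn = (~pred & labels).sum(axis=1)
    precision = tp / (tp + fp + eps)                    # eps avoids 0/0 when nothing is predicted
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)  # harmonic mean
    return precision, recall, f1


scores = [0.95, 0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2]   # made-up confidences
labels = [1, 1, 0, 1, 0, 1, 0, 0]                    # 1 = actually positive
thresholds = np.array([0.1, 0.25, 0.35, 0.5, 0.65, 0.75, 0.85])
p, r, f1 = f1_curve(scores, labels, thresholds)
best = f1.argmax()
print(f"best threshold {thresholds[best]:.2f}: P={p[best]:.2f} R={r[best]:.2f} F1={f1[best]:.2f}")
```

On this toy data the best threshold is 0.35 (P=0.67, R=1.00, F1=0.80): lower thresholds admit more false positives, while higher ones start losing true positives.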

【loss curve】

Overall, the m model is the best of the three, and the n and s models perform similarly.

Next, we prune the three models. Here we use a very useful third-party library, torch_pruning (Torch-Pruning). The official project address is as follows:

Installation is simple:

pip install torch-pruning
or
git clone https://github.com/VainF/Torch-Pruning.git

In structural pruning, a "group" is defined as the smallest unit that can be removed from a deep network. A group consists of multiple interdependent layers that must be pruned together to keep the resulting structure valid. However, deep networks often contain complex dependencies between layers, which makes structural pruning challenging. Torch-Pruning addresses this challenge with an automated mechanism called "DepGraph", which groups parameters automatically and makes pruning possible across many kinds of deep networks.
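To make the notion of a group concrete, here is a toy sketch of the idea, not Torch-Pruning's implementation: coupled channel dimensions are hand-written as undirected edges for a hypothetical stem conv followed by one residual block, and a breadth-first search from the pruned layer collects everything that must be pruned with it. DepGraph derives such dependencies automatically by tracing the model; all layer names below are made up.

```python
from collections import defaultdict, deque

# Hypothetical coupling edges: both ends of each pair must keep the same channels.
# Stem: conv1 -> bn1 -> relu; block: relu -> convA -> bnA -> reluA -> convB -> bnB -> add -> next
EDGES = [
    ("conv1.out", "bn1"), ("bn1", "relu"),
    ("relu", "convA.in"), ("relu", "add"),          # identity branch feeds the residual add
    ("add", "bnB"), ("bnB", "convB.out"),           # the add forces the block output to match
    ("add", "next.in"),
    ("convA.out", "bnA"), ("bnA", "reluA"), ("reluA", "convB.in"),  # block-internal group
]


def build_graph(edges):
    graph = defaultdict(set)
    for a, b in edges:
        graph[a].add(b)
        graph[b].add(a)  # coupling is symmetric
    return graph


def pruning_group(graph, root):
    """Breadth-first search: every layer dimension reachable from `root` is pruned with it."""
    seen, order, queue = {root}, [], deque([root])
    while queue:
        node = queue.popleft()
        order.append(node)
        for nxt in sorted(graph[node]):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return order


graph = build_graph(EDGES)
print(pruning_group(graph, "conv1.out"))
print(pruning_group(graph, "convA.out"))
```

Pruning the stem's output drags in the BN, the residual add, the block's last conv and BN, and the inputs of every consumer, while the block-internal channels form a separate, independent group.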

The official repository provides many usable examples, such as the following:

【Naive pruning】

To demonstrate the implications of dependencies, let’s try structured pruning on ResNet-18. The following code snippet attempts to remove channels indexed by 0 and 1 from the first model.conv1:

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) # remove channel 0 and channel 1
# test: this forward pass raises a RuntimeError, because bn1 still expects 64 channels
output = model(torch.randn(1,3,224,224))


Printing the model shows the mismatch: conv1 now outputs 62 channels, but bn1 and layer1 still expect 64:
ResNet(
  (conv1): Conv2d(3, 62, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
...

【An improved version】

In fact, the dependencies in this case are more complex than what we have observed so far. Let's improve the code by also handling the BN layer and the next convolution, and see what happens:

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1])
tp.prune_batchnorm_out_channels(model.bn1, idxs=[0,1])
tp.prune_conv_in_channels(model.layer1[0].conv1, idxs=[0,1])  # a Conv2d, so use the conv (not BN) function
# Still fails: the residual add in layer1[0] combines 62 identity channels with 64 from bn2
output = model(torch.randn(1,3,224,224))
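At the tensor level, pruning conv1's outputs, bn1, and the input channels of layer1[0].conv1 amounts to deleting the same channel indices along different axes. The NumPy sketch below (random weights with ResNet-18's conv1/bn1/layer1[0].conv1 shapes) only illustrates the bookkeeping and is not how Torch-Pruning implements these functions. It also hints at why the manual approach stays fragile: every other layer that consumes or adds to conv1's output, such as the residual branch, needs the same treatment.

```python
import numpy as np


def prune_coupled(conv_w, bn, next_conv_w, idxs):
    """Remove output channels `idxs` from a conv, its BatchNorm, and the next conv's inputs."""
    conv_w = np.delete(conv_w, idxs, axis=0)              # conv weight: (out, in, kH, kW)
    bn = {k: np.delete(v, idxs) for k, v in bn.items()}   # per-channel BN params and statistics
    next_conv_w = np.delete(next_conv_w, idxs, axis=1)    # next conv consumes these as inputs
    return conv_w, bn, next_conv_w


rng = np.random.default_rng(0)
conv1_w = rng.standard_normal((64, 3, 7, 7))
bn1 = {name: rng.standard_normal(64) for name in ("weight", "bias", "running_mean")}
bn1["running_var"] = np.ones(64)
layer1_conv1_w = rng.standard_normal((64, 64, 3, 3))

new_conv, new_bn, new_next = prune_coupled(conv1_w, bn1, layer1_conv1_w, idxs=[0, 1])
print(new_conv.shape, new_bn["weight"].shape, new_next.shape)
```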

【A Minimal Example】

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )

# 3. prune all grouped layers that are coupled with model.conv1 (included).
if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
    group.prune()
    
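# 4. Save & Load
model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # save the whole model object (not .state_dict), since pruned shapes differ from resnet18()
model = torch.load('model.pth', weights_only=False) # load the model object; PyTorch >= 2.6 defaults to weights_only=True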

The above example demonstrates a basic pruning pipeline using DepGraph. The target layer model.conv1 is coupled to multiple layers, which need to be pruned simultaneously in structural pruning. Let's print the group and observe how pruning operations "trigger" other pruning operations. In the following output, A=>B means that pruning operation A triggers pruning operation B, and group[0] is the pruning root passed to DG.get_pruning_group.

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
--------------------------------

【High-level Pruners】

Building on DependencyGraph, the repository provides several high-level pruners that make pruning easy. By specifying the desired channel sparsity, you can prune the entire model and then fine-tune it with your own training code. For more details on the process, see the official tutorial, which shows how to implement a slimming pruner from scratch. More practical examples can be found in benchmarks/main.py.

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)

# Importance criterion
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}; newer releases name this argument pruning_ratio
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
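    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"step {i+1}: MACs {base_macs/1e9:.2f}G -> {macs/1e9:.2f}G, params {base_nparams/1e6:.2f}M -> {nparams/1e6:.2f}M")
    # finetune your model here
    # finetune(model)
    # ...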
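The iterative_steps argument spreads the target sparsity over several pruning rounds. A minimal sketch, assuming a linear ramp (to my understanding Torch-Pruning's default schedule, though its exact rounding of channel counts may differ from this sketch), reproduces the 64/128/256/512 to 32/64/128/256 widths from the ch_sparsity comment above. The helper names are hypothetical.

```python
def linear_schedule(target_ratio: float, steps: int) -> list[float]:
    """Cumulative fraction of channels removed after each of `steps` pruning rounds."""
    return [target_ratio * (k + 1) / steps for k in range(steps)]


def remaining_channels(width: int, ratio: float) -> int:
    """Channels left when `ratio` of `width` is removed (the rounding rule is an assumption)."""
    return int(round(width * (1 - ratio)))


widths = [64, 128, 256, 512]  # ResNet-18 stage widths
schedule = linear_schedule(0.5, steps=5)
for step, ratio in enumerate(schedule, start=1):
    print(f"step {step}: ratio {ratio:.1f} -> {[remaining_channels(w, ratio) for w in widths]}")
```

Pruning gradually, with fine-tuning between rounds, lets the network adapt before the next cut instead of losing half its channels at once.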

There are many other examples that I won't go into here; you can refer to them in the repository. Based on the official examples, I pruned the three YOLOv5 models (n/s/m) of different parameter magnitudes. The results after pruning are as follows:

Next, I want to evaluate the pruned models directly, without retraining. Unsurprisingly, the results should be very poor; let's take a brief look first.

【yolov5n_layer_pruning】

【yolov5s_layer_pruning】

【yolov5m_layer_pruning】

The results are indeed terrible. The pruned model cannot be used directly: pruning breaks the original, complete model structure and invalidates much of the knowledge the model had learned.

Next, the pruned models need to be fine-tuned on the pruned structure. I again kept the same parameter settings as in the original training:

#yolov5n
python3 train.py --weights yolov5n_layer_pruning.pt --pt --name yolov5n_pruning --epochs 100 --batch-size 4 --img-size 416

#yolov5s
python3 train.py --weights yolov5s_layer_pruning.pt --pt --name yolov5s_pruning --epochs 100 --batch-size 4 --img-size 416

#yolov5m
python3 train.py --weights yolov5m_layer_pruning.pt --pt --name yolov5m_pruning --epochs 100 --batch-size 4 --img-size 416

Strictly speaking, 100 epochs are not necessary here, but I kept the same default settings for consistency and simply waited for the runs to finish before recording the results.

【yolov5n_pruning】

【yolov5s_pruning】

【yolov5m_pruning】

Here, judging from the evaluation results: n

【F1 value】

【Precision】

【Recall Rate】

【loss】

The three sets of pruning results above are all based on 30% pruning. After fine-tuning, the pruned models even perform better than the original models, which suggests that the original models contain a considerable amount of parameter redundancy.
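Assuming "30% pruning" here means removing 30% of the channels of each prunable layer (the exact setting is not spelled out), the saving on an interior convolution is roughly quadratic, because both its input and output channels shrink. A quick back-of-the-envelope check on a hypothetical 128-to-128 3x3 convolution:

```python
def conv_params(c_in: int, c_out: int, k: int, bias: bool = False) -> int:
    """Number of weights in a k x k Conv2d (plus bias terms if present)."""
    return c_out * c_in * k * k + (c_out if bias else 0)


def pruned_width(c: int, ratio: float) -> int:
    """Channels kept after removing `ratio` of them (truncation is an assumption)."""
    return int(c * (1 - ratio))


ratio = 0.3
c = 128  # hypothetical interior 3x3 conv with 128 input and 128 output channels
before = conv_params(c, c, 3)
after = conv_params(pruned_width(c, ratio), pruned_width(c, ratio), 3)
print(f"{before} -> {after} weights ({after / before:.1%} kept)")
print(f"ideal estimate: (1 - {ratio})^2 = {(1 - ratio) ** 2:.2f}")
```

Roughly half of such a layer's weights disappear. Layers at the network boundary (the 3-channel input and the Detect head outputs) are typically not pruned on both sides, so the whole-model reduction is usually smaller than this per-layer figure.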

Next, I want to further explore how different pruning levels affect model performance. Having learned from past experience, I don't dare put too much content into a single CSDN post, since it would be painful if the page suddenly crashed...

I put the content of this part in the next blog post, as follows:
“Developing and constructing a tea bud detection and identification model based on YOLOv5n/s/m models with different parameter levels. It is expected that pruning technology will be used to lightweight the model and explore the impact of model performance under different pruning levels [continued]”