Transformer-based decoder object detection framework (modified from the DETR source code)

Tip: a Transformer-style object detection decoder, including the loss calculation, with source code attached.

Contents

  • Preface
  • 1. Interpretation of the main function code
    • 1. Understanding the overall structure
    • 2. Interpretation of the main function code
    • 3. Source code link
  • 2. Interpretation of the decoder module code
    • 1. Interpretation of the TransformerDec module code
    • 2. Interpretation of the TransformerDecoder module code
    • 3. Interpretation of the DecoderLayer module code
  • 3. Interpretation of the decoder training demo code
    • 1. Input data format for decoding
    • 2. Interpretation of the decoder training demo code
  • 4. Interpretation of the decoder prediction demo code
    • 1. Input data format for prediction
    • 2. Interpretation of the decoder prediction demo code
  • 5. Interpretation of the losses module code
    • 1. Matcher initialization
    • 2. Interpretation of the bipartite matcher code
    • 3. Interpretation of the num_classes parameter
    • 4. Interpretation of the losses demo code

Preface

Recently I revisited the DETR model, and I increasingly appreciate how elegant its structure is. Unlike anchor-based and anchor-free designs, it directly produces predictions from 100 object queries, uses learnable query embeddings to probe the encoded features, and trains the model with bipartite matching. To that end, I extracted the decoding and loss-calculation modules from the DETR source code, then restructured, modified, and integrated them into a standalone decoding-and-loss framework. This framework can be attached to any backbone feature extractor to carry out complete training and prediction, and I provide corresponding demos showing how to use it. In the rest of this article I walk through the framework's source code in full. I have also open-sourced the code and uploaded it to GitHub for readers' reference.

1. Interpretation of the main function code

1. Understanding the overall structure

Before introducing the main function code, let me describe the overall framework structure. The framework contains two folders: the losses folder, which handles the loss calculation, and the obj_det folder, which contains the Transformer decoding module (its source code is adapted from the DETR model). There is also main.py, which is the end-to-end demo of decoding plus loss calculation. The layout is sketched below.
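The sketch below shows my reading of the layout; it is inferred from the imports used in main.py (file names beyond those imports are assumptions):

losses/
    matcher.py          # HungarianMatcher (bipartite matching)
    loss.py             # SetCriterion (loss calculation)
obj_det/
    transformer_obj.py  # TransformerDec, TransformerDecoder, DecoderLayer
main.py                 # end-to-end decoding + loss demo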

2. Interpretation of the main function code

This code is a demo in which I randomly create label (target) data, backbone feature data, and position-encoding data so that the pipeline can run end to end. The code is as follows:

import torch
from obj_det.transformer_obj import TransformerDec
from losses.matcher import HungarianMatcher
from losses.loss import SetCriterion

if __name__ == '__main__':

    model = TransformerDec(d_model=256, output_intermediate_dec=True, num_classes=4)

    num_classes = 4  # actual number of categories + 1 (background)
    matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)  # bipartite-matching weights for the different cost terms
    losses = ['labels', 'boxes', 'cardinality']  # tasks for which loss is computed
    weight_dict = {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}  # loss weights for the final decoder layer
    criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=0.1, losses=losses)

    # Below I construct dummy encoder output (standing in for backbone features) and dummy label data
    src = torch.rand((391, 2, 256))        # flattened features: [HW, batch, d_model]
    pos_embed = torch.ones((391, 1, 256))  # position encoding

    # Create dummy ground-truth targets
    target1 = {'boxes': torch.rand((5, 4)), 'labels': torch.tensor([1, 3, 2, 1, 2])}
    target2 = {'boxes': torch.rand((3, 4)), 'labels': torch.tensor([1, 1, 2])}
    target = [target1, target2]

    res = model(src, pos_embed)
    loss_dict = criterion(res, target)
    print(loss_dict)

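Running the demo prints the loss dictionary. As a rough sketch of the expected output, assuming the criterion behaves like the stock DETR SetCriterion (with output_intermediate_dec=True, auxiliary losses are added per intermediate decoder layer; the values themselves are random here):

{'loss_ce': tensor(...), 'loss_bbox': tensor(...), 'loss_giou': tensor(...),
 'class_error': tensor(...), 'cardinality_error': tensor(...),
 'loss_ce_0': tensor(...), 'loss_bbox_0': tensor(...), 'loss_giou_0': tensor(...),
 ...
 'loss_ce_4': tensor(...), 'loss_bbox_4': tensor(...), 'loss_giou_4': tensor(...)}

Only the keys present in weight_dict (and their per-layer suffixed copies) count toward the weighted total; class_error and cardinality_error are reported for logging only.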

3. Source code link

GitHub source code link: click here
Baidu Netdisk source code link:
Link: https://pan.baidu.com/s/1r9q_et6AVT6Rdx7_2-7X5w
Extraction code: detr

2. Interpretation of the decoder module code

This module decodes the features extracted by the backbone with a Transformer, mainly using learnable queries and the related tricks of the Transformer decoding approach.
I will mainly introduce the TransformerDec, TransformerDecoder, and DecoderLayer modules. They are nested in that order, i.e., each later module is a component of the one before it, as sketched below.
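A rough sketch of the nesting (module names are from this repo; the per-layer wiring follows the standard DETR decoder):

TransformerDec              # learnable query embedding + decoder call + prediction heads
  └── TransformerDecoder    # loops over num_decoder_layers, optionally keeping intermediate outputs
        └── DecoderLayer    # self-attention over queries → cross-attention to encoder memory → FFN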

1. Interpretation of the TransformerDec module code

The general idea of this class is that it contains the learnable query embedding, the call into the Transformer decoder module, and the prediction heads for logits and boxes, i.e., everything needed to implement decoding and prediction. The parameters of this module are explained in its docstring, which readers can inspect for themselves. The code is as follows:

class TransformerDec(nn.Module):
    '''
    d_model=512, the representation width, i.e. the encoder output embedding dimension
    nhead=8, number of attention heads
    num_queries=100, number of object queries; the queries are learnable
    num_decoder_layers=6, number of stacked decoder layers
    dim_feedforward=2048, hidden width of the FFN (two nn.Linear layers)
    dropout=0.1,
    activation="relu",
    normalize_before=False, the decoder supports 2 variants; the default False uses the post-norm decoder structure
    output_intermediate_dec=False, if True, the intermediate decoding results are kept (i.e. every decoder layer's output is saved); if False, only the last layer's result is kept. Use True for training and False for inference.
    num_classes: the value of num_classes depends on the data format. If category id=1 denotes the first category, then num_classes = actual number of categories + 1; if id=0 denotes the first category, then num_classes = actual number of categories.

    Additional explanation: COCO category ids start at 1. If there are three classes named [dog, cat, pig] and batch=2, then num_classes=4 means 3 classes + 1 background.
    The model output src_logits=[2,100,5] thus has one extra prediction slot; target_classes has shape [2,100] and is filled with the value 4 (this value is the background; the real class values are 1, 2 and 3).
    There is then no value 0 in target_classes. My understanding is that the model does nothing with class 0, which is an invalid slot; the loss is only computed over 1, 2, 3 and 4, where 4 is the extra background class.
    The authors down-weight the background with a weight of 0.1 (eos_coef) to avoid it dominating the loss.

    forward return: a dictionary containing {
    'pred_logits': [],  # tensor of shape [b, 100, num_classes + 1]
    'pred_boxes': [],   # tensor of shape [b, 100, 4]
    'aux_outputs': [{}, ...]  # a list whose elements are dictionaries, each {'pred_logits': [], 'pred_boxes': []}, with the same formats as above
    }
    '''

    def __init__(self, d_model=512, nhead=8, num_queries=100, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False, output_intermediate_dec=False, num_classes=1):
        super().__init__()

        self.num_queries = num_queries
        self.query_embed = nn.Embedding(num_queries, d_model)  # same dimension as the encoder output
        self.output_intermediate_dec = output_intermediate_dec

        decoder_layer = DecoderLayer(d_model, nhead, dim_feedforward,
                                     dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)