Object detection algorithm improvement series: replacing the backbone with VanillaNet

Introduction to VanillaNet

Introduction: VanillaNet is a neural network architecture that embraces elegance in its design. By avoiding high depth, shortcuts, and intricate operations such as self-attention, VanillaNet is simple yet powerful. Each layer is crafted to be compact and straightforward, and the non-linear activation functions are pruned after training so that the network is restored to its original plain architecture. VanillaNet overcomes the challenges of inherent complexity, which makes it well suited to resource-constrained environments. Its easy-to-understand and highly simplified architecture opens up new possibilities for efficient deployment. Extensive experiments show that VanillaNet performs on par with well-known deep neural networks and vision transformers, demonstrating the power of minimalism in deep learning. This work has great potential to redefine the landscape and challenge the status quo of foundation models, opening up a new path for elegant and effective model design.

VanillaNet-6 model

VanillaNet code implementation

#Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved.

#This program is free software; you can redistribute it and/or modify it under the terms of the MIT License.

#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MIT License for more details.

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import weight_init, DropPath
import numpy as np

# Public API: one factory function per VanillaNet variant defined below.
__all__ = ['vanillanet_5', 'vanillanet_6', 'vanillanet_7', 'vanillanet_8', 'vanillanet_9', 'vanillanet_10', 'vanillanet_11', 'vanillanet_12' , 'vanillanet_13', 'vanillanet_13_x1_5', 'vanillanet_13_x1_5_ada_pool']

class activation(nn.ReLU):
    """ReLU followed by a learnable depthwise convolution.

    In training form (``deploy=False``) the convolution output is passed
    through a BatchNorm layer. ``switch_to_deploy`` folds that BatchNorm into
    the convolution weight and a bias, after which the module is a plain
    ReLU + depthwise convolution.

    Args:
        dim: Number of channels; the convolution uses ``groups=dim``.
        act_num: Half-width of the kernel; the kernel is
            ``(2 * act_num + 1)`` squared and padded by ``act_num`` so the
            spatial size is preserved.
        deploy: Build the module directly in fused form (bias parameter, no
            BatchNorm) so that a state dict saved after ``switch_to_deploy``
            can be loaded into it.
    """

    def __init__(self, dim, act_num=3, deploy=False):
        super(activation, self).__init__()
        self.deploy = deploy
        self.dim = dim
        self.act_num = act_num
        kernel_size = act_num * 2 + 1
        self.weight = torch.nn.Parameter(torch.randn(dim, 1, kernel_size, kernel_size))
        if deploy:
            # Fused form: the BatchNorm has been absorbed into weight/bias, so
            # do not register BN parameters that would never be used.
            self.bias = torch.nn.Parameter(torch.zeros(dim))
        else:
            self.bias = None
            self.bn = nn.BatchNorm2d(dim, eps=1e-6)
        nn.init.trunc_normal_(self.weight, std=.02)

    def forward(self, x):
        """Apply ReLU, then the depthwise convolution (plus BatchNorm when not fused)."""
        rectified = super(activation, self).forward(x)
        if self.deploy:
            return torch.nn.functional.conv2d(
                rectified, self.weight, self.bias, padding=self.act_num, groups=self.dim)
        return self.bn(torch.nn.functional.conv2d(
            rectified, self.weight, padding=self.act_num, groups=self.dim))

    def _fuse_bn_tensor(self, weight, bn):
        """Return ``(kernel, bias)`` equivalent to a bias-free conv with ``weight`` followed by ``bn``.

        Uses the BatchNorm running statistics, i.e. inference behaviour.
        """
        std = (bn.running_var + bn.eps).sqrt()
        scale = bn.weight / std
        # The convolution has no bias of its own, hence the (0 - running_mean).
        return weight * scale.reshape(-1, 1, 1, 1), bn.bias + (0 - bn.running_mean) * scale

    def switch_to_deploy(self):
        """Fold the BatchNorm into the convolution in place; a no-op if already fused."""
        if self.deploy:
            return
        with torch.no_grad():
            kernel, bias = self._fuse_bn_tensor(self.weight, self.bn)
        self.weight.data = kernel
        self.bias = torch.nn.Parameter(bias)
        self.__delattr__('bn')
        self.deploy = True


class Block(nn.Module):
    """One VanillaNet stage: 1x1 conv(s) -> optional max pooling -> ``activation``.

    In training form the stage holds two 1x1 conv + BatchNorm pairs separated
    by a leaky ReLU whose negative slope is ``act_learn``. In deploy form it
    holds a single 1x1 convolution.

    Args:
        dim: Input channels.
        dim_out: Output channels.
        act_num: Forwarded to ``activation``.
        stride: Pooling stride; ``1`` disables pooling.
        deploy: Build the fused (single convolution) form.
        ada_pool: If truthy and ``stride != 1``, pool adaptively to an
            ``ada_pool x ada_pool`` map instead of using ``MaxPool2d(stride)``.
    """

    def __init__(self, dim, dim_out, act_num=3, stride=2, deploy=False, ada_pool=None):
        super().__init__()
        # Negative slope of the leaky ReLU between conv1 and conv2; at 1 the
        # leaky ReLU is the identity.
        self.act_learn = 1
        self.deploy = deploy
        if self.deploy:
            self.conv = nn.Conv2d(dim, dim_out, kernel_size=1)
        else:
            self.conv1 = nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=1),
                nn.BatchNorm2d(dim, eps=1e-6),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(dim, dim_out, kernel_size=1),
                nn.BatchNorm2d(dim_out, eps=1e-6)
            )

        if stride == 1:
            self.pool = nn.Identity()
        elif ada_pool:
            self.pool = nn.AdaptiveMaxPool2d((ada_pool, ada_pool))
        else:
            self.pool = nn.MaxPool2d(stride)

        # Forward the deploy flag so a deploy-built block does not apply an
        # un-fused BatchNorm inside its activation.
        self.act = activation(dim_out, act_num, deploy=self.deploy)

    def forward(self, x):
        """Run the convolution(s), pooling and activation on ``x``."""
        if self.deploy:
            x = self.conv(x)
        else:
            x = self.conv1(x)
            x = torch.nn.functional.leaky_relu(x, self.act_learn)
            x = self.conv2(x)

        x = self.pool(x)
        x = self.act(x)
        return x

    def _fuse_bn_tensor(self, conv, bn):
        """Return ``(kernel, bias)`` of a conv equivalent to ``conv`` followed by ``bn``.

        Uses the BatchNorm running statistics, i.e. inference behaviour.
        """
        std = (bn.running_var + bn.eps).sqrt()
        scale = bn.weight / std
        return conv.weight * scale.reshape(-1, 1, 1, 1), bn.bias + (conv.bias - bn.running_mean) * scale

    def switch_to_deploy(self):
        """Collapse conv1/BN/conv2/BN into one 1x1 convolution, in place.

        The leaky ReLU between the two convolutions is not represented in the
        merged layer, so the result matches the training form only when
        ``act_learn`` is 1. A no-op if the block is already in deploy form.
        """
        if self.deploy:
            return
        with torch.no_grad():
            k1, b1 = self._fuse_bn_tensor(self.conv1[0], self.conv1[1])
            k2, b2 = self._fuse_bn_tensor(self.conv2[0], self.conv2[1])
            w1 = k1.squeeze(3).squeeze(2)  # (dim, dim)
            w2 = k2.squeeze(3).squeeze(2)  # (dim_out, dim)
            # Two stacked 1x1 convs compose as W = W2 @ W1, b = b2 + W2 @ b1.
            merged_weight = torch.matmul(w2, w1)[:, :, None, None]
            merged_bias = b2 + torch.matmul(w2, b1)
        self.conv = self.conv2[0]
        self.conv.weight.data = merged_weight
        self.conv.bias.data = merged_bias
        self.__delattr__('conv1')
        self.__delattr__('conv2')
        self.act.switch_to_deploy()
        self.deploy = True
    

class VanillaNet(nn.Module):
    """VanillaNet backbone that returns multi-scale feature maps.

    ``forward`` returns a list of four slots, one per down-sampling ratio
    4, 8, 16 and 32 (input height divided by feature height). A slot holds the
    last feature map produced at that ratio, or ``None`` if none was produced.

    Args:
        in_chans: Channels of the input image.
        num_classes: Accepted for signature compatibility; not used here.
        dims: Channel widths; ``dims[0]`` is the stem width and stage ``i``
            maps ``dims[i]`` to ``dims[i + 1]``, so at least
            ``len(strides) + 1`` entries are required.
        drop_rate: Accepted for signature compatibility; not used here.
        act_num: Forwarded to every ``activation``.
        strides: Pooling stride of each stage; its length sets the depth.
        deploy: Build the fused inference form directly.
        ada_pool: Optional per-stage adaptive pooling sizes (see ``Block``).

    Attributes:
        channel: Channel count of each of the four feature slots, measured
            with a 640x640 probe input; 0 for a slot that stays empty.

    Raises:
        ValueError: If ``dims`` or ``ada_pool`` is too short for ``strides``.
    """

    def __init__(self, in_chans=3, num_classes=1000, dims=[96, 192, 384, 768],
                 drop_rate=0, act_num=3, strides=[2,2,2,1], deploy=False, ada_pool=None, **kwargs):
        super().__init__()
        if len(dims) < len(strides) + 1:
            raise ValueError(
                f'dims needs at least len(strides) + 1 = {len(strides) + 1} entries, got {len(dims)}')
        if ada_pool and len(ada_pool) < len(strides):
            raise ValueError(
                f'ada_pool needs at least len(strides) = {len(strides)} entries, got {len(ada_pool)}')

        self.deploy = deploy
        if self.deploy:
            self.stem = nn.Sequential(
                nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
                # Forward the deploy flag so the stem activation is built fused too.
                activation(dims[0], act_num, deploy=self.deploy)
            )
        else:
            self.stem1 = nn.Sequential(
                nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
                nn.BatchNorm2d(dims[0], eps=1e-6),
            )
            self.stem2 = nn.Sequential(
                nn.Conv2d(dims[0], dims[0], kernel_size=1, stride=1),
                nn.BatchNorm2d(dims[0], eps=1e-6),
                activation(dims[0], act_num)
            )

        # Negative slope of the leaky ReLU inside the stem (1 = identity).
        self.act_learn = 1

        self.stages = nn.ModuleList()
        for i, stride in enumerate(strides):
            stage_pool = ada_pool[i] if ada_pool else None
            self.stages.append(Block(dim=dims[i], dim_out=dims[i + 1], act_num=act_num,
                                     stride=stride, deploy=deploy, ada_pool=stage_pool))
        self.depth = len(strides)

        self.apply(self._init_weights)

        # Probe the network once to learn the channel count of every feature
        # slot. Run it in eval mode without autograd so the random probe does
        # not update the BatchNorm running statistics or build a graph.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            probe_features = self.forward(torch.randn(1, in_chans, 640, 640))
        self.train(was_training)
        self.channel = [0 if feat is None else feat.size(1) for feat in probe_features]

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for conv and linear layers."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            weight_init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def change_act(self, m):
        """Set the leaky-ReLU negative slope ``m`` on the stem and on every stage."""
        for stage in self.stages:
            stage.act_learn = m
        self.act_learn = m

    def forward(self, x):
        """Return the four-slot feature list described in the class docstring."""
        input_size = x.size(2)
        scale = [4, 8, 16, 32]
        features = [None, None, None, None]
        if self.deploy:
            x = self.stem(x)
        else:
            x = self.stem1(x)
            x = torch.nn.functional.leaky_relu(x, self.act_learn)
            x = self.stem2(x)
        ratio = input_size // x.size(2)
        if ratio in scale:
            features[scale.index(ratio)] = x
        for stage in self.stages:
            x = stage(x)
            ratio = input_size // x.size(2)
            if ratio in scale:
                # Later stages at the same ratio overwrite earlier ones.
                features[scale.index(ratio)] = x
        return features

    def _fuse_bn_tensor(self, conv, bn):
        """Return ``(kernel, bias)`` of a conv equivalent to ``conv`` followed by ``bn``.

        Uses the BatchNorm running statistics, i.e. inference behaviour.
        """
        std = (bn.running_var + bn.eps).sqrt()
        scale = bn.weight / std
        return conv.weight * scale.reshape(-1, 1, 1, 1), bn.bias + (conv.bias - bn.running_mean) * scale

    def switch_to_deploy(self):
        """Fuse the stem and every stage into their inference form, in place.

        The stem's leaky ReLU is not represented in the merged convolution, so
        the result matches the training form only when ``act_learn`` is 1.
        A no-op if the network is already in deploy form.
        """
        if self.deploy:
            return
        self.stem2[2].switch_to_deploy()
        with torch.no_grad():
            k1, b1 = self._fuse_bn_tensor(self.stem1[0], self.stem1[1])
            k2, b2 = self._fuse_bn_tensor(self.stem2[0], self.stem2[1])
            w2 = k2.squeeze(3).squeeze(2)  # 1x1 kernel as an (out, in) matrix
            # A 4x4 conv followed by a 1x1 conv is one 4x4 conv whose kernel is
            # the channel-wise product W2 @ W1 and whose bias is b2 + W2 @ b1.
            merged_weight = torch.einsum('oi,icjk->ocjk', w2, k1)
            merged_bias = b2 + torch.matmul(w2, b1)
        stem_conv = self.stem1[0]
        stem_conv.weight.data = merged_weight
        stem_conv.bias.data = merged_bias
        self.stem = torch.nn.Sequential(*[stem_conv, self.stem2[2]])
        self.__delattr__('stem1')
        self.__delattr__('stem2')

        for stage in self.stages:
            stage.switch_to_deploy()

        self.deploy = True

def update_weight(model_dict, weight_dict):
    """Copy shape-compatible entries of ``weight_dict`` into ``model_dict``.

    An entry is copied only when its key already exists in ``model_dict`` and
    both values have the same shape; everything else is ignored.

    Args:
        model_dict: Mapping of parameter name to tensor/array; updated in place.
        weight_dict: Mapping of candidate replacement values.

    Returns:
        ``model_dict`` (the same object that was passed in).
    """
    matched = {
        k: v for k, v in weight_dict.items()
        if k in model_dict and np.shape(model_dict[k]) == np.shape(v)
    }
    model_dict.update(matched)
    print(f'loading weights... {len(matched)}/{len(model_dict)} items')
    return model_dict

def _load_pretrained(model, pretrained):
    """Load the ``'model_ema'`` weights of checkpoint ``pretrained`` into ``model``.

    Does nothing when ``pretrained`` is falsy. Only entries whose name and
    shape match the model are taken (see ``update_weight``). The checkpoint is
    mapped to CPU so that files saved on a GPU also load on CPU-only hosts.

    Returns:
        ``model``, for call chaining.
    """
    if pretrained:
        weights = torch.load(pretrained, map_location='cpu')['model_ema']
        model.load_state_dict(update_weight(model.state_dict(), weights))
    return model


# Every factory below takes:
#   pretrained: path of a checkpoint to load, or a falsy value to skip loading.
#   in_22k:     accepted for signature compatibility; not used.
#   **kwargs:   forwarded to VanillaNet.

def vanillanet_5(pretrained='',in_22k=False, **kwargs):
    """Build VanillaNet-5 (three stages) and optionally load weights."""
    model = VanillaNet(dims=[128*4, 256*4, 512*4, 1024*4], strides=[2,2,2], **kwargs)
    return _load_pretrained(model, pretrained)

def vanillanet_6(pretrained='',in_22k=False, **kwargs):
    """Build VanillaNet-6 (four stages) and optionally load weights."""
    model = VanillaNet(dims=[128*4, 256*4, 512*4, 1024*4, 1024*4], strides=[2,2,2,1], **kwargs)
    return _load_pretrained(model, pretrained)

def vanillanet_7(pretrained='',in_22k=False, **kwargs):
    """Build VanillaNet-7 (five stages) and optionally load weights."""
    model = VanillaNet(dims=[128*4, 128*4, 256*4, 512*4, 1024*4, 1024*4], strides=[1,2,2,2,1], **kwargs)
    return _load_pretrained(model, pretrained)

def vanillanet_8(pretrained='', in_22k=False, **kwargs):
    """Build VanillaNet-8 (six stages) and optionally load weights."""
    model = VanillaNet(
        dims=[128*4, 128*4, 256*4, 512*4, 512*4, 1024*4, 1024*4],
        strides=[1,2,2,1,2,1],
        **kwargs)
    return _load_pretrained(model, pretrained)

def vanillanet_9(pretrained='', in_22k=False, **kwargs):
    """Build VanillaNet-9 (seven stages) and optionally load weights."""
    model = VanillaNet(
        dims=[128*4, 128*4, 256*4, 512*4, 512*4, 512*4, 1024*4, 1024*4],
        strides=[1,2,2,1,1,2,1],
        **kwargs)
    return _load_pretrained(model, pretrained)

def vanillanet_10(pretrained='', in_22k=False, **kwargs):
    """Build VanillaNet-10 (eight stages) and optionally load weights."""
    model = VanillaNet(
        dims=[128*4, 128*4, 256*4, 512*4, 512*4, 512*4, 512*4, 1024*4, 1024*4],
        strides=[1,2,2,1,1,1,2,1],
        **kwargs)
    return _load_pretrained(model, pretrained)

def vanillanet_11(pretrained='', in_22k=False, **kwargs):
    """Build VanillaNet-11 (nine stages) and optionally load weights."""
    model = VanillaNet(
        dims=[128*4, 128*4, 256*4, 512*4, 512*4, 512*4, 512*4, 512*4, 1024*4, 1024*4],
        strides=[1,2,2,1,1,1,1,2,1],
        **kwargs)
    return _load_pretrained(model, pretrained)

def vanillanet_12(pretrained='', in_22k=False, **kwargs):
    """Build VanillaNet-12 (ten stages) and optionally load weights."""
    model = VanillaNet(
        dims=[128*4, 128*4, 256*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 1024*4, 1024*4],
        strides=[1,2,2,1,1,1,1,1,2,1],
        **kwargs)
    return _load_pretrained(model, pretrained)

def vanillanet_13(pretrained='', in_22k=False, **kwargs):
    """Build VanillaNet-13 (eleven stages) and optionally load weights."""
    model = VanillaNet(
        dims=[128*4, 128*4, 256*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 1024*4, 1024*4],
        strides=[1,2,2,1,1,1,1,1,1,2,1],
        **kwargs)
    return _load_pretrained(model, pretrained)

def vanillanet_13_x1_5(pretrained='', in_22k=False, **kwargs):
    """Build VanillaNet-13 with 1.5x channel widths and optionally load weights."""
    model = VanillaNet(
        dims=[128*6, 128*6, 256*6, 512*6, 512*6, 512*6, 512*6, 512*6, 512*6, 512*6, 1024*6, 1024*6],
        strides=[1,2,2,1,1,1,1,1,1,2,1],
        **kwargs)
    return _load_pretrained(model, pretrained)

def vanillanet_13_x1_5_ada_pool(pretrained='', in_22k=False, **kwargs):
    """Build the 1.5x-wide VanillaNet-13 whose pooling stages use adaptive max pooling."""
    model = VanillaNet(
        dims=[128*6, 128*6, 256*6, 512*6, 512*6, 512*6, 512*6, 512*6, 512*6, 512*6, 1024*6, 1024*6],
        strides=[1,2,2,1,1,1,1,1,1,2,1],
        # Output size of the adaptive pool per stage; 0 means "use the stride".
        ada_pool=[0,40,20,0,0,0,0,0,0,10,0],
        **kwargs)
    return _load_pretrained(model, pretrained)

if __name__ == '__main__':
    # Smoke test: push one random 640x640 RGB image through VanillaNet-10 and
    # print the size of each returned feature map.
    dummy_image = torch.randn((1, 3, 640, 640))
    net = vanillanet_10()
    # To try pretrained weights, load a checkpoint's 'model_ema' entry through
    # update_weight() and net.load_state_dict() before the forward pass.
    for feature_map in net(dummy_image):
        print(feature_map.size())

Backbone replacement

Modifications to yolo.py

The parse_model function

def parse_model(d, ch):  # model_dict, input_channels(3)
    """Build the layer stack described by a YOLOv5 model.yaml dictionary.

    Besides the regular YOLOv5 modules, a layer may be a multi-output backbone
    (a VanillaNet factory or a timm model name). Such a layer reports a list of
    output channels; it is flagged with ``backbone = True`` and the index of
    every layer from that point on is shifted by 4, because the backbone fills
    five slots of the per-layer output list.

    Args:
        d: Parsed model dictionary with the keys ``anchors``, ``nc``,
            ``depth_multiple``, ``width_multiple``, ``backbone`` and ``head``
            (``activation`` is optional).
        ch: List of input channel counts; the last entry seeds ``c2``.

    Returns:
        ``(nn.Sequential of all layers, sorted list of layer indices to save)``.
    """
    # Parse a YOLOv5 model.yaml dictionary
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
    if act:
        # NOTE(security): eval() executes text taken from the YAML file; only load trusted configs.
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        LOGGER.info(f"{colorstr('activation:')} {act}")  # print
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    is_backbone = False
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        t = m  # keep the YAML name for logging of backbone layers
        if isinstance(m, str):
            try:
                m = eval(m)  # eval strings
            except Exception:
                # Not a name in scope: leave it as a string so the branch
                # below can hand it to timm.create_model.
                pass
        for j, a in enumerate(args):
            if isinstance(a, str):
                try:
                    args[j] = eval(a)  # eval strings
                except Exception:
                    pass  # plain string argument such as 'nearest'; keep as is

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in {
                Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x}:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        # TODO: channel, gw, gd
        elif m in {Detect, Segment}:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, 8)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        elif isinstance(m, str):
            # Unresolved name: treat it as a timm model used as backbone.
            t = m
            m = timm.create_model(m, pretrained=args[0], features_only=True)
            c2 = m.feature_info.channels()
        elif m in {vanillanet_5, vanillanet_6}:  # Add Backbone
            m = m(*args)
            c2 = m.channel  # list of channels, one per feature slot
        else:
            c2 = ch[f]
        if isinstance(c2, list):
            # A list of output channels marks a multi-output backbone; m is
            # already an instantiated module at this point.
            is_backbone = True
            m_ = m
            m_.backbone = True
        else:
            m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
            t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i + 4 if is_backbone else i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}')  # print
        save.extend(x % (i + 4 if is_backbone else i) for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        if isinstance(c2, list):
            ch.extend(c2)
            # Left-pad with zeros so the backbone fills five channel slots.
            for _ in range(5 - len(ch)):
                ch.insert(0, 0)
        else:
            ch.append(c2)
    return nn.Sequential(*layers), sorted(save)

The _forward_once function

def _forward_once(self, x, profile=False, visualize=False):
    """Run ``x`` through every layer of ``self.model`` and return the last output.

    A layer carrying a ``backbone`` attribute returns a list of feature maps.
    That list is left-padded with ``None`` to five entries, each entry is
    stored (or ``None`` if its position is not in ``self.save``), and the last
    feature map is passed on to the next layer.
    """
    outputs, timings = [], []  # per-layer outputs, profiling times
    for layer in self.model:
        source = layer.f
        if source != -1:  # input does not come from the previous layer
            if isinstance(source, int):
                x = outputs[source]
            else:
                x = [x if j == -1 else outputs[j] for j in source]
        if profile:
            self._profile_one_layer(layer, x, timings)
        if hasattr(layer, 'backbone'):
            feats = layer(x)
            # Pad at the front so the feature maps occupy five output slots.
            feats[:0] = [None] * (5 - len(feats))
            outputs.extend(feat if slot in self.save else None for slot, feat in enumerate(feats))
            x = feats[-1]
        else:
            x = layer(x)  # run
            outputs.append(x if layer.i in self.save else None)  # save output
        if visualize:
            feature_visualization(x, layer.type, layer.i, save_dir=visualize)
    return x

Create .yaml configuration file

# YOLOv5 by Ultralytics, GPL-3.0 license
# Model definition that uses a VanillaNet backbone in place of the stock YOLOv5 backbone.
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23] # P3/8
  - [30,61, 62,45, 59,119] # P4/16
  - [116,90, 156,198, 373,326] # P5/32

# Layer indices taken by the backbone's outputs (index - pyramid level/stride).
# The backbone returns four feature maps (strides 4, 8, 16, 32); they are padded
# to five slots, so index 0 is an empty placeholder.
#0-P1/2
#1-P2/4
#2-P3/8
#3-P4/16
#4-P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  # vanillanet_5 args: [pretrained]; False builds the network without loading weights.
  [[-1, 1, vanillanet_5, [False]], # 4 (backbone fills indices 0-4)
   [-1, 1, SPPF, [1024, 5]], # 5
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]], # 6
   [-1, 1, nn.Upsample, [None, 2, 'nearest']], # 7
   [[-1, 3], 1, Concat, [1]], # cat backbone P4 8
   [-1, 3, C3, [512, False]], # 9

   [-1, 1, Conv, [256, 1, 1]], # 10
   [-1, 1, nn.Upsample, [None, 2, 'nearest']], # 11
   [[-1, 2], 1, Concat, [1]], # cat backbone P3 12
   [-1, 3, C3, [256, False]], # 13 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]], # 14
   [[-1, 10], 1, Concat, [1]], # cat head P4 15
   [-1, 3, C3, [512, False]], # 16 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]], # 17
   [[-1, 5], 1, Concat, [1]], # cat head P5 18 -- NOTE(review): index 5 is the SPPF output, not the 1x1 Conv at index 6; confirm this is intended
   [-1, 3, C3, [1024, False]], # 19 (P5/32-large)

   [[13, 16, 19], 1, Detect, [nc, anchors]], # 20 Detect(P3, P4, P5)
  ]