EfficientViT: Multi-scale linear attention for high-resolution dense prediction

Title: EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction

Paper: https://arxiv.org/abs/2205.14756

Chinese version: "[Reading Paper] EfficientViT: Enhanced Linear Attention for High-Resolution Low-Computation, Transforming Softmax Attention into Linear Attention", Shaner's Blog (CSDN)

Code: https://codeload.github.com/mit-han-lab/efficientvit/zip/refs/heads/master

Table of Contents

1. Abstract

2. Main contributions

3. Methodology

3.1 Multi-Scale Linear Attention

3.2 EfficientViT architecture

4. Experiment

4.1 Ablation studies

4.2 Semantic segmentation experiment

5. Summary


1. Abstract

Research background: High-resolution dense prediction enables many appealing real-world applications, such as computational photography and autonomous driving. However, the huge computational cost makes it difficult to deploy state-of-the-art high-resolution dense prediction models on hardware devices.

Main work: This paper proposes EfficientViT, a new family of high-resolution vision models built on multi-scale linear attention. Previous high-resolution dense prediction models rely on heavy softmax attention, hardware-inefficient large-kernel convolutions, or complex topologies to achieve good performance. Multi-scale linear attention instead uses only lightweight, hardware-efficient operations. With these operations it still achieves a global receptive field and multi-scale learning, which are two desirable properties for high-resolution dense prediction.

Research results: EfficientViT delivers significant performance gains and speedups over previous state-of-the-art models on a variety of hardware platforms, including mobile CPUs, edge GPUs, and cloud GPUs. Without performance loss on the Cityscapes dataset, EfficientViT provides up to 13.9x and 6.2x GPU latency reduction over SegFormer and SegNeXt, respectively. For super-resolution, EfficientViT provides up to 6.4x speedup over Restormer while delivering a 0.11 dB PSNR gain.

2. Main contributions

1. We introduce a new multi-scale linear attention module for efficient high-resolution dense prediction. It achieves a global receptive field and multi-scale learning while maintaining good hardware efficiency. To the best of our knowledge, our work is the first to demonstrate the effectiveness of linear attention for high-resolution dense prediction.

2. We design EfficientViT, a new family of high-resolution vision models based on the proposed multi-scale linear attention module.

3. EfficientViT achieves significant speedups over previous SOTA models on semantic segmentation, super-resolution, Segment Anything, and ImageNet classification across different hardware platforms (mobile CPU, edge GPU, and cloud GPU).

3. Methodology

3.1 Multi-Scale Linear Attention

Multi-scale linear attention achieves a global receptive field and multi-scale learning at the same time, using only hardware-efficient operations. Building on it, the authors propose EfficientViT, a new family of vision transformer models for high-resolution dense prediction.

Motivation: Global receptive fields and multi-scale learning are essential for performance. Previous SOTA high-resolution dense prediction models achieve strong performance by providing these features, but they are not hardware-efficient. The multi-scale linear attention module addresses this by trading a slight performance loss for a significant efficiency gain.

Method: Use ReLU linear attention to achieve global receptive fields instead of heavy softmax attention.

Derivation of the formula of ReLU linear attention:

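Softmax attention can be written in a generalized form. Let $Q = xW_Q$, $K = xW_K$, $V = xW_V$, and let $\mathrm{Sim}(\cdot,\cdot)$ be a similarity function. The output for query $i$ is then

$$O_i = \sum_{j=1}^{N} \frac{\mathrm{Sim}(Q_i, K_j)}{\sum_{j=1}^{N} \mathrm{Sim}(Q_i, K_j)} V_j \qquad (1)$$

Softmax attention uses $\mathrm{Sim}(Q, K) = \exp\left(QK^T/\sqrt{d}\right)$. ReLU linear attention replaces it with $\mathrm{Sim}(Q, K) = \mathrm{ReLU}(Q)\,\mathrm{ReLU}(K)^T$, which gives:

$$O_i = \sum_{j=1}^{N} \frac{\mathrm{ReLU}(Q_i)\mathrm{ReLU}(K_j)^T}{\sum_{j=1}^{N} \mathrm{ReLU}(Q_i)\mathrm{ReLU}(K_j)^T} V_j = \frac{\sum_{j=1}^{N} \left(\mathrm{ReLU}(Q_i)\mathrm{ReLU}(K_j)^T\right) V_j}{\mathrm{ReLU}(Q_i) \sum_{j=1}^{N} \mathrm{ReLU}(K_j)^T} \qquad (2)$$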

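The associative law of matrix multiplication applies to the numerator: $\left(\mathrm{ReLU}(Q_i)\mathrm{ReLU}(K_j)^T\right)V_j = \mathrm{ReLU}(Q_i)\left(\mathrm{ReLU}(K_j)^T V_j\right)$. This means $\mathrm{ReLU}(Q_i)$ can be moved outside both sums:

$$O_i = \frac{\mathrm{ReLU}(Q_i) \left(\sum_{j=1}^{N} \mathrm{ReLU}(K_j)^T V_j\right)}{\mathrm{ReLU}(Q_i) \left(\sum_{j=1}^{N} \mathrm{ReLU}(K_j)^T\right)} \qquad (3)$$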

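Conclusion of the derivation: As shown in formula (3), the terms $\left(\sum_{j=1}^{N} \mathrm{ReLU}(K_j)^T V_j\right) \in \mathbb{R}^{d\times d}$ and $\left(\sum_{j=1}^{N} \mathrm{ReLU}(K_j)^T\right) \in \mathbb{R}^{d\times 1}$ only need to be computed once. Neither term depends on the query, so both can be reused for every query $Q_i$ without building an N × N similarity matrix. ReLU linear attention therefore requires only O(N) computational cost and O(N) memory, versus O(N²) for softmax attention.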
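The derivation can be checked numerically with a minimal NumPy sketch. This is an illustration, not the authors' code. It evaluates ReLU attention in two ways:

- the quadratic way, with an explicit N × N similarity matrix;
- the linear way of formula (3), with the two query-independent sums computed only once.

The function names, the random data and the small `EPS` added to both denominators (mirroring `eps` in LiteMLA) are assumptions made for this example.

```python
import numpy as np

EPS = 1e-15  # small constant in the denominators, like eps in LiteMLA


def relu(x):
    return np.maximum(x, 0.0)


def relu_attention_quadratic(Q, K, V):
    # Quadratic form: explicit N x N similarity matrix, O(N^2) time and memory
    sim = relu(Q) @ relu(K).T                                  # (N, N)
    return (sim @ V) / (sim.sum(axis=1, keepdims=True) + EPS)


def relu_attention_linear(Q, K, V):
    # Formula (3): the query-independent sums are computed once, O(N) cost
    kv = relu(K).T @ V                      # sum_j ReLU(K_j)^T V_j, shape (d, d)
    k_sum = relu(K).sum(axis=0)             # sum_j ReLU(K_j)^T, shape (d,)
    num = relu(Q) @ kv                      # (N, d)
    den = relu(Q) @ k_sum                   # (N,)
    return num / (den[:, None] + EPS)


rng = np.random.default_rng(0)
N, d = 64, 8
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
print(np.allclose(relu_attention_quadratic(Q, K, V), relu_attention_linear(Q, K, V)))  # True
```

Both functions return the same (N, d) output. Only the linear version avoids materializing the N × N matrix, which is what makes it practical at high resolution.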

Limitations of ReLU linear attention: The paper compares the attention maps produced by softmax attention and by ReLU linear attention. ReLU linear attention lacks a nonlinear similarity function, so it cannot generate concentrated attention maps and is weak at capturing local information.
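A toy example shows why the nonlinear similarity matters. The numbers are hypothetical and are not taken from the paper's figure.

- With non-negative features, ReLU linear attention weights are simply proportional to the dot products.
- Softmax exponentiates the dot products first, which concentrates the weight on the best-matching key.

The 1/√d scaling of softmax is omitted to keep the example small.

```python
import numpy as np

# Hypothetical toy case: one query and four keys with non-negative features
q = np.array([1.0, 0.0])
keys = np.array([[4.0, 1.0], [2.0, 3.0], [1.0, 0.0], [0.5, 2.0]])

scores = keys @ q                                   # dot products: [4, 2, 1, 0.5]

# Softmax attention: exponentiate, then normalize (1/sqrt(d) scaling omitted)
softmax_w = np.exp(scores) / np.exp(scores).sum()

# ReLU linear attention: Sim(q, k) = ReLU(q) ReLU(k)^T, then normalize
relu_sim = np.maximum(keys, 0.0) @ np.maximum(q, 0.0)
relu_w = relu_sim / relu_sim.sum()

print("softmax:", softmax_w.round(3))
print("ReLU linear:", relu_w.round(3))
```

Here the largest softmax weight is about 0.82. The largest ReLU linear attention weight is only 4/7.5 ≈ 0.53, so its attention is spread more evenly over all keys.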

Solution:

1. To alleviate these limitations, we propose enhancing ReLU linear attention with convolution. Specifically, a depthwise convolution is inserted into each FFN layer. ReLU linear attention captures contextual (global) information, while FFN + DWConv captures local information.

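2. To strengthen the multi-scale learning ability of ReLU linear attention, information from nearby Q/K/V tokens is aggregated with small-kernel convolutions to produce multi-scale tokens. Here "multi-scale" refers to the spatial extent of the aggregation. The original tokens come from a 1×1 projection, while each aggregation branch covers a larger neighborhood (for example 5×5 with the default `scales=(5,)`). The aggregation is performed independently for the Q, K and V of each head.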

Executing these aggregations independently is inefficient on GPUs. In practice, all DWConvs are therefore fused into a single DWConv, and all 1×1 convs are fused into a single 1×1 group convolution with 3 × #heads groups of d channels each. ReLU linear attention is then applied to the multi-scale tokens to extract multi-scale global features. Finally, the features are concatenated along the head dimension and fed to the final linear layer, which fuses them.

(In other words, the groups parameter of nn.Conv2d() splits the input and output channels into groups that are convolved separately. Here it serves efficiency: one grouped convolution performs all the per-head, per-Q/K/V aggregations at once without mixing them.)
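The fused aggregation can be sketched with plain nn.Conv2d layers. The setup is a simplified, hypothetical one: 4 heads, d = 8, a single 5×5 branch, and no norm or bias. It mirrors the `aggreg` branch of the LiteMLA code below but is not a drop-in replacement. The printed weight shapes show how the `groups` argument partitions the channels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
heads, dim = 4, 8                   # hypothetical sizes: 4 heads, d = 8 channels each
total_dim = heads * dim
scale = 5                           # kernel size of the aggregation branch

aggreg = nn.Sequential(
    # depthwise 5x5: each Q/K/V channel gathers its own 5x5 spatial neighborhood
    nn.Conv2d(3 * total_dim, 3 * total_dim, scale, padding=scale // 2,
              groups=3 * total_dim, bias=False),
    # grouped 1x1: 3 * heads groups of d channels, so Q, K and V of each head stay separate
    nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=False),
)

qkv = torch.randn(1, 3 * total_dim, 16, 16)       # stand-in for the 1x1 qkv projection output
multi_scale_qkv = torch.cat([qkv, aggreg(qkv)], dim=1)

print(aggreg[0].weight.shape)   # torch.Size([96, 1, 5, 5])
print(aggreg[1].weight.shape)   # torch.Size([96, 8, 1, 1])
print(multi_scale_qkv.shape)    # torch.Size([1, 192, 16, 16])
```

The concatenation doubles the channel count, so the attention step sees 2 × heads heads: one set per scale. This matches the `total_dim * (1 + len(scales))` input width of `proj` in LiteMLA.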

Q: What is the relationship between receptive field and attention mechanism?

A: The attention mechanism can capture long-distance dependencies by calculating the relationship between different positions, thereby expanding the receptive field and improving the perceptual ability of the network.
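A rough comparison makes the difference concrete. A stack of stride-1 k×k convolutions grows its receptive field by only k − 1 pixels per layer, whereas one attention layer relates every position to every other.

The helpers below count how many 3×3 layers are needed to span a 256-pixel-wide feature map, for example a 2048-pixel-wide input at 1/8 resolution. The helper names are hypothetical, and strides and dilation are ignored.

```python
def conv_receptive_field(num_layers: int, kernel_size: int = 3) -> int:
    # stride-1 convolutions: each extra k x k layer widens the field by k - 1 pixels
    return 1 + num_layers * (kernel_size - 1)


def layers_to_cover(width: int, kernel_size: int = 3) -> int:
    # smallest number of stacked stride-1 convs whose receptive field spans `width`
    layers = 0
    while conv_receptive_field(layers, kernel_size) < width:
        layers += 1
    return layers


feature_width = 2048 // 8   # hypothetical: a 2048-pixel-wide image at 1/8 resolution
print(layers_to_cover(feature_width))   # 128 stacked 3x3 convs
# a single global attention layer already relates all positions of the feature map
```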

The code is as follows:

Lightweight multi-scale attention module

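import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

# ConvLayer, build_act, get_same_padding, val2tuple, ResidualBlock, IdentityLayer and
# OpSequential are helper modules/functions from the EfficientViT repository (not shown here)


# Lightweight multi-scale attention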
class LiteMLA(nn.Module):
    r"""Lightweight multi-scale linear attention"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads: int or None = None,
        heads_ratio: float = 1.0,
        dim=8,
        use_bias=False,
        norm=(None, "bn2d"),
        act_func=(None, None),
        kernel_func="relu",
        scales: tuple[int, ...] = (5,),
        eps=1.0e-15,
    ):
        super(LiteMLA, self).__init__()
        self.eps = eps
        heads = heads or int(in_channels // dim * heads_ratio)

        total_dim = heads * dim

        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)

        self.dim = dim
        self.qkv = ConvLayer(
            in_channels,
            3 * total_dim,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act_func=act_func[0],
        )
        self.aggreg = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        3 * total_dim,
                        3 * total_dim,
                        scale,
                        padding=get_same_padding(scale),
                        groups=3 * total_dim,
                        bias=use_bias[0],
                    ),
                    nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
                )
                for scale in scales
            ]
        )
        # Each branch: a depthwise (scale x scale) conv aggregates spatially neighboring
        # Q/K/V tokens, then a 1x1 group conv (groups = 3 * heads, d channels per group)
        # mixes channels only within the Q, K or V of a single head
        self.kernel_func = build_act(kernel_func, inplace=False) # Relu activation function

        self.proj = ConvLayer(
            total_dim * (1 + len(scales)),
            out_channels,
            1,
            use_bias=use_bias[1],
            norm=norm[1],
            act_func=act_func[1],
        )

    @autocast(enabled=False)
    def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
        B, _, H, W = list(qkv.size())

        if qkv.dtype == torch.float16:
            qkv = qkv.float()

        qkv = torch.reshape(
            qkv,
            (
                B,
                -1,
                3 * self.dim,
                H*W,
            ),
        )
        qkv = torch.transpose(qkv, -1, -2)
        q, k, v = (
            qkv[..., 0 : self.dim],
            qkv[..., self.dim : 2 * self.dim],
            qkv[..., 2 * self.dim :],
        )

        # lightweight linear attention
        q = self.kernel_func(q) # Perform relu activation
        k = self.kernel_func(k) # Perform relu activation

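        # linear matmul
        trans_k = k.transpose(-1, -2)

        # Append a column of ones to V: the last column of ReLU(K)^T [V, 1] is sum_j ReLU(K_j)^T,
        # so a single matmul yields both the numerator and the denominator of formula (3)
        v = F.pad(v, (0, 1), mode="constant", value=1)
        kv = torch.matmul(trans_k, v)  # computed once and shared by all queries
        out = torch.matmul(q, kv)
        out = out[..., :-1] / (out[..., -1:] + self.eps)  # numerator / denominator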

        out = torch.transpose(out, -1, -2)
        out = torch.reshape(out, (B, -1, H, W))
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # generate multi-scale q, k, v
        qkv = self.qkv(x)  # Q, K, V from a single 1x1 convolution
        multi_scale_qkv = [qkv]
        for op in self.aggreg:  # aggregate neighboring tokens with small-kernel convs (one extra scale per branch)
            multi_scale_qkv.append(op(qkv))
        multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)  # concatenate all scales along the channel (head) dimension

        out = self.relu_linear_att(multi_scale_qkv)  # split back into Q, K, V per head and apply ReLU linear attention
        out = self.proj(out)  # 1x1 convolution acting as the final linear projection that fuses all heads

        return out
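The `F.pad(v, (0, 1), value=1)` line in `relu_linear_att` is a trick. Appending a column of ones to V makes the last column of ReLU(Q)(ReLU(K)^T [V, 1]) equal to the denominator of formula (3). One matmul therefore produces both the numerator and the denominator.

Below is a minimal NumPy check of this trick. It assumes a single head, no batch dimension and random data.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 32, 8
q = np.maximum(rng.standard_normal((N, d)), 0)   # ReLU(Q)
k = np.maximum(rng.standard_normal((N, d)), 0)   # ReLU(K)
v = rng.standard_normal((N, d))
eps = 1.0e-15

# Trick used in relu_linear_att: append a column of ones to V
v_pad = np.pad(v, ((0, 0), (0, 1)), mode="constant", constant_values=1.0)  # (N, d + 1)
out = q @ (k.T @ v_pad)                                                   # (N, d + 1)
out = out[:, :-1] / (out[:, -1:] + eps)

# Direct evaluation of formula (3): numerator and denominator computed separately
num = q @ (k.T @ v)
den = q @ k.sum(axis=0)
print(np.allclose(out, num / (den[:, None] + eps)))   # True
```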

3.2 EfficientViT Architecture

As shown in the paper's architecture figure, EfficientViT consists of a backbone and a head:

Backbone: It consists of an input stem and four stages. From stage to stage, the feature map size decreases and the number of channels increases. EfficientViT modules are inserted in stages 3 and 4, and MBConv with stride 2 is used for downsampling.
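The sketch below traces feature-map sizes through the input stem and the four stages, which makes the downsampling schedule concrete. It is a hypothetical helper, not code from the repository. It follows the backbone code, where the stem is a stride-2 conv and each stage starts with a stride-2 block, and it assumes the input height and width are divisible by 32.

```python
def stage_resolutions(height: int, width: int) -> dict:
    # Assumes height and width are divisible by 32 so every stride-2 step halves exactly
    sizes = {}
    h, w = height // 2, width // 2          # input stem: stride-2 conv
    sizes["stage0"] = (h, w)
    for stage in range(1, 5):               # stages 1-4: first block has stride 2
        h, w = h // 2, w // 2
        sizes[f"stage{stage}"] = (h, w)
    return sizes


for name, (h, w) in stage_resolutions(1024, 2048).items():
    print(name, h, w)
```

For a 1024×2048 Cityscapes input, the outputs of stages 2, 3 and 4 (P2, P3, P4) are at 1/8, 1/16 and 1/32 resolution: 128×256, 64×128 and 32×64.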

Head (segmentation head): P2, P3 and P4 denote the outputs of stages 2, 3 and 4, which form a feature pyramid. For simplicity and efficiency, 1×1 convolutions and standard upsampling (e.g., bilinear/bicubic) are used to match their spatial and channel sizes, and the features are fused by addition. The rest of the head is simple: a few MBConv blocks followed by the output layers (i.e., prediction and upsampling).
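Below is a minimal sketch of the fusion step described above:

1. One 1×1 conv per pyramid level matches the channel counts.
2. Bilinear upsampling brings every level to the P2 resolution.
3. Element-wise addition fuses the levels.

The class name, channel widths and head width are hypothetical. The MBConv blocks and output layers that follow the fusion are omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleFusion(nn.Module):
    """Hypothetical fusion sketch: 1x1 conv + bilinear upsampling + addition."""

    def __init__(self, in_channels=(64, 128, 256), head_width=64):
        super().__init__()
        # one 1x1 conv per pyramid level (P2, P3, P4) to match channel counts
        self.lateral = nn.ModuleList([nn.Conv2d(c, head_width, 1) for c in in_channels])

    def forward(self, feats):
        target = feats[0].shape[-2:]  # fuse at the resolution of P2
        fused = None
        for conv, f in zip(self.lateral, feats):
            f = conv(f)
            if f.shape[-2:] != target:
                f = F.interpolate(f, size=target, mode="bilinear", align_corners=False)
            fused = f if fused is None else fused + f
        return fused


p2 = torch.randn(1, 64, 128, 256)   # stage 2 output, 1/8 of a 1024x2048 input
p3 = torch.randn(1, 128, 64, 128)   # stage 3 output, 1/16
p4 = torch.randn(1, 256, 32, 64)    # stage 4 output, 1/32
print(SimpleFusion()([p2, p3, p4]).shape)  # torch.Size([1, 64, 128, 256])
```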

The code is as follows:

Backbone

class EfficientViTBackbone(nn.Module):
    # Backbone: input_stem + stage1 + stage2 + stage3 + stage4
    def __init__(
            self,
            width_list: list[int],
            depth_list: list[int],
            in_channels=3,
            dim=32,
            expand_ratio=4,
            norm="bn2d",
            act_func="hswish",
    ) -> None:
        super().__init__()

        self.width_list = []
        # input stem
        self.input_stem = [
            ConvLayer(
                in_channels=in_channels,
                out_channels=width_list[0],
                stride=2,
                norm=norm,
                act_func=act_func,
            ) # 3x3 convolution -> downsample 2 times
        ]
        for _ in range(depth_list[0]):
            block = self.build_local_block( # Build DSConv module to capture local information
                in_channels=width_list[0],
                out_channels=width_list[0],
                stride=1,
                expand_ratio=1,
                norm=norm,
                act_func=act_func,
            )
            self.input_stem.append(ResidualBlock(block, IdentityLayer())) # Add residual
        in_channels = width_list[0]
        self.input_stem = OpSequential(self.input_stem) # Add each module in the input_stem stage to the ModuleList in order
        self.width_list.append(in_channels) #Add the number of channels of each module to width_list

        # stages
        self.stages = []

        # stages 1 and 2: MBConv blocks only (the first block of each stage has stride 2)
        for w, d in zip(width_list[1:3], depth_list[1:3]):
            stage = []
            for i in range(d):
                stride = 2 if i == 0 else 1
                block = self.build_local_block( # Build MBConv module to capture local information
                    in_channels=in_channels,
                    out_channels=w,
                    stride=stride,
                    expand_ratio=expand_ratio,
                    norm=norm,
                    act_func=act_func,
                )
                block = ResidualBlock(block, IdentityLayer() if stride == 1 else None) # Increase residual
                stage.append(block)
                in_channels = w
            self.stages.append(OpSequential(stage))
            self.width_list.append(in_channels)

        for w, d in zip(width_list[3:], depth_list[3:]):
            stage = []

            # stages 3 and 4 start with a stride-2 MBConv for downsampling
            block = self.build_local_block( # Build MBConv module to capture local information
                in_channels=in_channels,
                out_channels=w,
                stride=2,
                expand_ratio=expand_ratio,
                norm=norm,
                act_func=act_func,
                fewer_norm=True,
            )
            stage.append(ResidualBlock(block, None))
            in_channels = w

            # ...followed by EfficientViT blocks (multi-scale linear attention + MBConv)
            for _ in range(d):
                stage.append(
                    EfficientViTBlock( # EfficientViTBlock module, multi-scale attention extraction contextual features
                        in_channels=in_channels,
                        dim=dim,
                        expand_ratio=expand_ratio,
                        norm=norm,
                        act_func=act_func,
                    )
                )
            self.stages.append(OpSequential(stage))
            self.width_list.append(in_channels)
        self.stages = nn.ModuleList(self.stages) # nn.ModuleList, used to store different modules and automatically add the parameters of each module to the network

    # Build a DSConv (expand_ratio == 1) or MBConv block -> local information
    @staticmethod
    def build_local_block(
            in_channels: int,
            out_channels: int,
            stride: int,
            expand_ratio: float,
            norm: str,
            act_func: str,
            fewer_norm: bool = False,
    ) -> nn.Module:
        if expand_ratio == 1:
            block = DSConv( # DSConv module
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        else:
            block = MBConv(  # mobile inverted bottleneck conv; downsamples 2x when stride=2
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False) if fewer_norm else False,
                norm=(None, None, norm) if fewer_norm else norm,
                act_func=(act_func, act_func, None),
            )
        return block

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        output_dict = {"input": x}
        output_dict["stage0"] = x = self.input_stem(x)
        for stage_id, stage in enumerate(self.stages, 1): # Backbone of the network
            output_dict["stage%d" % stage_id] = x = stage(x)
        output_dict["stage_final"] = x
        return output_dict

DSConv module

class DSConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        use_bias=False,
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super(DSConv, self).__init__()

        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)

        self.depth_conv = ConvLayer(
            in_channels,
            in_channels,
            kernel_size,
            stride,
            groups=in_channels,
            norm=norm[0],
            act_func=act_func[0],
            use_bias=use_bias[0],
        )
        self.point_conv = ConvLayer(
            in_channels,
            out_channels,
            1,
            norm=norm[1],
            act_func=act_func[1],
            use_bias=use_bias[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.depth_conv(x)
        x = self.point_conv(x)
        return x
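DSConv is cheap because a depthwise k×k conv followed by a pointwise 1×1 conv needs far fewer weights than a standard k×k conv with the same input and output channels. The comparison below uses plain nn.Conv2d layers with hypothetical channel counts. It ignores the norm and activation layers that ConvLayer adds.

```python
import torch.nn as nn


def num_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


c_in, c_out, k = 64, 128, 3  # hypothetical channel counts and kernel size

standard = nn.Conv2d(c_in, c_out, k, padding=1, bias=False)
separable = nn.Sequential(
    nn.Conv2d(c_in, c_in, k, padding=1, groups=c_in, bias=False),  # depthwise 3x3
    nn.Conv2d(c_in, c_out, 1, bias=False),                        # pointwise 1x1
)

print(num_params(standard))   # 64 * 128 * 3 * 3 = 73728
print(num_params(separable))  # 64 * 3 * 3 + 64 * 128 = 8768
```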

MBConv module

# MBConv
class MBConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        use_bias=False,
        norm=("bn2d", "bn2d", "bn2d"),
        act_func=("relu6", "relu6", None),
    ):
        super(MBConv, self).__init__()

        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act_func = val2tuple(act_func, 3)
        mid_channels = mid_channels or round(in_channels * expand_ratio)

        self.inverted_conv = ConvLayer(
            in_channels,
            mid_channels,
            1,
            stride=1,
            norm=norm[0],
            act_func=act_func[0],
            use_bias=use_bias[0],
        )
        self.depth_conv = ConvLayer(
            mid_channels,
            mid_channels,
            kernel_size,
            stride=stride,
            groups=mid_channels,
            norm=norm[1],
            act_func=act_func[1],
            use_bias=use_bias[1],
        )
        self.point_conv = ConvLayer(
            mid_channels,
            out_channels,
            1,
            norm=norm[2],
            act_func=act_func[2],
            use_bias=use_bias[2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
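        x = self.inverted_conv(x)  # 1x1 conv: expand in_channels -> mid_channels
        x = self.depth_conv(x)  # depthwise kxk conv on mid_channels (downsamples when stride=2)
        x = self.point_conv(x)  # 1x1 conv: project mid_channels -> out_channels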
        return x

EfficientViTBlock module

# EfficientViTBlock module -> Extract contextual features
class EfficientViTBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        heads_ratio: float = 1.0,
        dim=32,
        expand_ratio: float = 4,
        norm="bn2d",
        act_func="hswish",
    ):
        super(EfficientViTBlock, self).__init__()
        self.context_module = ResidualBlock(
            LiteMLA( #Lightweight multi-scale attention
                in_channels=in_channels,
                out_channels=in_channels,
                heads_ratio=heads_ratio,
                dim=dim,
                norm=(None, norm),
            ),
            IdentityLayer(),
        )
        local_module = MBConv(
            in_channels=in_channels,
            out_channels=in_channels,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False),
            norm=(None, None, norm),
            act_func=(act_func, act_func, None),
        )
        self.local_module = ResidualBlock(local_module, IdentityLayer()) # Add residual connection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.context_module(x)  # lightweight multi-scale attention -> global context features
        x = self.local_module(x)  # MBConv with depthwise conv (the FFN + DWConv) -> local features
        return x

4. Experiment

Datasets: Cityscapes and ADE20K datasets.

Evaluation metrics: mIoU, number of parameters (Params), and MACs (multiply-accumulate operations).
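For reference, the MACs of a convolution layer are usually counted as H_out × W_out × C_out × (C_in / groups) × k × k. The helper below is a simple sketch of that count; its name is hypothetical, and bias and normalization are ignored. It is applied to a 3×3 conv on a 1/8-resolution feature map of a 1024×2048 input.

```python
def conv_macs(h_out, w_out, c_in, c_out, k, groups=1):
    # every output element needs (c_in / groups) * k * k multiply-accumulates
    return h_out * w_out * c_out * (c_in // groups) * k * k


h, w = 128, 256    # hypothetical 1/8-resolution feature map of a 1024x2048 input
c = 128
print(conv_macs(h, w, c, c, 3) / 1e9)            # standard 3x3 conv: about 4.83 GMACs
print(conv_macs(h, w, c, c, 3, groups=c) / 1e9)  # depthwise 3x3 conv: about 0.0377 GMACs
```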

4.1 Ablation Study

(1) Ablation of the EfficientViT module

mIoU and MACs are measured on Cityscapes at an input resolution of 1024×2048. Model widths are rescaled so that all variants have the same MACs. The paper's ablation table shows that both multi-scale learning and the global receptive field are crucial for good semantic segmentation performance.

(2) Backbone performance comparison on ImageNet

EfficientViT-L2-r384 achieves 86.0% top-1 accuracy on ImageNet. This is a +0.3 accuracy gain over EfficientNetV2-L, with a 2.6x speedup on an A100 GPU.

4.2 Semantic Segmentation Experiment

Comparison with advanced semantic segmentation models on the Cityscapes dataset.

Compared with SegFormer, EfficientViT achieves up to 13x fewer MACs and up to 8.8x lower latency on the edge GPU (Jetson AGX Orin), with higher mIoU. Compared with SegNeXt, EfficientViT delivers up to 2.0x MAC reduction and 3.8x speedup on the edge GPU while maintaining higher mIoU.

5. Summary

1. This paper introduces a lightweight multi-scale attention module for efficient architecture design in high-resolution dense prediction. The module achieves a global receptive field and multi-scale learning simultaneously with lightweight, hardware-efficient operations. As a result, it provides significant speedups on a variety of hardware devices without a performance penalty relative to SOTA high-resolution dense prediction models.

2. Multi-scale linear attention combines three components. ReLU linear attention provides the global receptive field, FFN + DWConv captures local information, and convolutional token aggregation captures multi-scale information. Together they overcome the weaknesses of lightweight ReLU linear attention.