ONNX export of the Swin Transformer

1. Configure the swin transformer environment according to the repo.

https://github.com/microsoft/Swin-Transformer

2. Create the file export.py in the repo directory.

run

`python export.py --eval --cfg configs/swin/swin_tiny_patch4_window7_224.yaml --resume ../weights/swin_tiny_patch4_window7_224.pth --data-path data/ --local_rank 0`

# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import argparse

import torch

from config import get_config
from models import build_model

import onnx
import onnxsim

def parse_option():
    """Parse command-line arguments and build the Swin Transformer config.

    Returns:
        tuple: (args, config) where ``args`` is the parsed ``argparse``
        namespace and ``config`` is the merged model/training config
        produced by ``get_config``.
    """
    parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        # BUG FIX: was nargs=' + ' (with spaces), which is not a valid argparse
        # nargs value; '+' means "one or more values".
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                        help='no: no cache, '
                             'full: cache all data, '
                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
    parser.add_argument('--pretrained',
                        help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
    parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used (deprecated!)')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')

    # distributed training
    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')

    # for acceleration
    parser.add_argument('--fused_window_process', action='store_true',
                        # BUG FIX: help text contained the HTML-escape artifact '& amp;'.
                        help='Fused window shift & window partition, similar for reversed part.')
    parser.add_argument('--fused_layernorm', action='store_true', help='Use fused layernorm.')
    # overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lamb
    parser.add_argument('--optim', type=str,
                        help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.')

    # parse_known_args so unrecognized launcher flags don't abort the script
    args, unparsed = parser.parse_known_args()

    config = get_config(args)

    return args, config

def export_norm_onnx(model, file, input, opset_version=12):
    """Export *model* to ONNX at path *file*, then verify and simplify it in place.

    Args:
        model: the (eval-mode) torch module to export.
        file: destination path for the .onnx file.
        input: example input tensor used to trace the graph.
        opset_version: ONNX opset to target. BUG FIX: this was hardcoded to 9,
            which contradicts both the output filename ("...opset12.onnx") and
            the custom ``roll`` symbolic registered in ``symbolic_opset12.py``
            (a symbolic registered for opset 12 is never used when exporting
            at opset 9). Default is now 12.
    """
    torch.onnx.export(
        model=model,
        args=(input,),
        f=file,
        input_names=["input0"],
        output_names=["output0"],
        opset_version=opset_version)

    print("Finished normal onnx export")

    # Re-load the exported file and run ONNX's structural validity check.
    model_onnx = onnx.load(file)
    onnx.checker.check_model(model_onnx)

    # Use onnx-simplifier to fold constants / remove redundant nodes,
    # then overwrite the original file with the simplified graph.
    print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...")
    model_onnx, check = onnxsim.simplify(model_onnx)
    assert check, "assert check failed"
    onnx.save(model_onnx, file)

def main(config):
    """Build the Swin model from *config* and export it to ONNX.

    NOTE(review): no checkpoint is loaded here, so the exported weights are
    the randomly initialized ones — load the ``--resume`` checkpoint into the
    model first if trained weights are wanted. TODO confirm intended behavior.
    """
    model = build_model(config)
    # eval() before export so dropout/batchnorm run in inference mode.
    model.eval()

    # Dummy (N, C, H, W) input matching the 224x224 config; values are
    # irrelevant, only the shape matters for tracing.
    dummy_input = torch.rand(1, 3, 224, 224)

    export_norm_onnx(model, "../models/swin-tiny-after-simplify-opset12.onnx", dummy_input)


if __name__ == '__main__':
    # Parse CLI flags into a merged config, then run the ONNX export.
    args, config = parse_option()
    main(config)

3. Solve the problem that the operator is not supported.

Running the above program will generate an error that the roll operator is not supported.

The solution is:

Find symbolic_opset12.py under the torch module in the configuration environment. For example, mine is:

/home/dkwang/Documents/miniconda3/envs/openmmlab/lib/python3.8/site-packages/torch/onnx/symbolic_opset12.py

Add the following registration information

# ONNX symbolic for aten::roll, to be added to torch's symbolic_opset12.py:
# older opsets have no native Roll op, so each rolled axis is emulated with
# two Slice ops plus a Concat. (parse_args / sym_help / maxsize are already
# in scope inside that torch module.)
@parse_args('v', 'is', 'is')
def roll(g, self, shifts, dims):
    # One (shift, dim) pair per rolled axis.
    assert len(shifts) == len(dims)

    result = self
    for i in range(len(shifts)):
        shapes = []
        # Tail slice x[-shift:] along the axis — the part that wraps to the front.
        shape = sym_help._slice_helper(g,
                                       result,
                                       axes=[dims[i]],
                                       starts=[-shifts[i]],
                                       ends=[maxsize])
        shapes.append(shape)
        # Head slice x[:-shift] along the axis — the part pushed toward the back.
        shape = sym_help._slice_helper(g,
                                       result,
                                       axes=[dims[i]],
                                       starts=[0],
                                       ends=[-shifts[i]])
        shapes.append(shape)
        # Concat(tail, head) along the rolled dimension reproduces torch.roll.
        result = g.op("Concat", *shapes, axis_i=dims[i])

    return result

Then run the export program as above.

4. Simplify Layer Normalization

Torch 2.0 supports ONNX opset 17, which allows LayerNormalization to be exported directly as a single node. Reconfigure the environment to torch 2.0.

Open the export results with netron

This completes the Swin Transformer ONNX export workflow.