Fixing ONNX export failure for a PyTorch model that calls grid_sample with 5-dimensional (5D) input

When converting a PyTorch model to ONNX, grid_sample has no corresponding ONNX operator, so the export fails. The workarounds I found online only handle four-dimensional (4D) input, but my input is five-dimensional (5D). I adapted one of those 4D implementations to the 5D case. The code below is not optimized for performance; its only goal is to make the export succeed.

import torch
from torch import nn
import torch.nn.functional as F

def grid_sample_3d(input, grid, align_corners, padding_mode="border"):
    """Trilinearly sample a 5-D volume at the locations given by ``grid``.

    A re-implementation of trilinear ``F.grid_sample`` built from elementary
    tensor operations (floor, clamp, gather, arithmetic), intended for
    exporters that have no operator for 5-D ``grid_sample``.

    Args:
        input: Tensor of shape ``(N, C, ID, IH, IW)``.
        grid: Tensor of shape ``(N, D, H, W, 3)`` holding normalized
            coordinates; the last axis is ``(x, y, z)``, indexing the
            ``IW``, ``IH`` and ``ID`` axes of ``input`` respectively.
        align_corners: If truthy, -1 and 1 map to the centers of the first
            and last voxels; otherwise they map to the outer edges of those
            voxels. ``None`` is treated as ``False``.
        padding_mode: ``"border"`` (default, the historical behaviour of this
            function) reuses the nearest edge voxel for out-of-range
            neighbours; ``"zeros"`` makes out-of-range neighbours contribute
            nothing.

    Returns:
        Tensor of shape ``(N, C, D, H, W)``.

    Raises:
        ValueError: If ``padding_mode`` is not ``"border"`` or ``"zeros"``.
    """
    if padding_mode not in ("border", "zeros"):
        raise ValueError(
            "padding_mode must be 'border' or 'zeros', got {!r}".format(padding_mode)
        )

    N, C, ID, IH, IW = input.shape
    _, D, H, W, _ = grid.shape

    ix = grid[..., 0]
    iy = grid[..., 1]
    iz = grid[..., 2]

    # Map normalized [-1, 1] coordinates to voxel-index space.
    if align_corners:
        ix = ((ix + 1) / 2) * (IW - 1)
        iy = ((iy + 1) / 2) * (IH - 1)
        iz = ((iz + 1) / 2) * (ID - 1)
    else:
        ix = ((ix + 1) * IW - 1) / 2
        iy = ((iy + 1) * IH - 1) / 2
        iz = ((iz + 1) * ID - 1) / 2

    # Integer lattice coordinates surrounding each sample point. They are
    # treated as constants so gradients only flow through the weights.
    with torch.no_grad():
        x0 = torch.floor(ix)
        y0 = torch.floor(iy)
        z0 = torch.floor(iz)
        x1 = x0 + 1
        y1 = y0 + 1
        z1 = z0 + 1

    # Per axis: (corner coordinate, interpolation weight of that corner).
    # Weights use the unclamped corners, so each axis' pair sums to one.
    x_axis = ((x0, x1 - ix), (x1, ix - x0))
    y_axis = ((y0, y1 - iy), (y1, iy - y0))
    z_axis = ((z0, z1 - iz), (z1, iz - z0))

    # reshape (not view) so non-contiguous inputs are accepted.
    flat_input = input.reshape(N, C, ID * IH * IW)

    out_val = None
    for cz, wz in z_axis:
        for cy, wy in y_axis:
            for cx, wx in x_axis:
                weight = wx * wy * wz

                with torch.no_grad():
                    if padding_mode == "zeros":
                        # Drop the contribution of corners outside the volume.
                        inside = (
                            (cx >= 0) & (cx <= IW - 1)
                            & (cy >= 0) & (cy <= IH - 1)
                            & (cz >= 0) & (cz <= ID - 1)
                        )
                    # Clamp so the gather index is always valid, then build
                    # the flat index in integer arithmetic; doing this in
                    # float32 loses exactness beyond 2**24 voxels.
                    x_idx = cx.clamp(0, IW - 1).long()
                    y_idx = cy.clamp(0, IH - 1).long()
                    z_idx = cz.clamp(0, ID - 1).long()
                    flat_idx = (z_idx * IH + y_idx) * IW + x_idx

                if padding_mode == "zeros":
                    weight = weight * inside.to(weight.dtype)

                # The same spatial index is used for every channel.
                flat_idx = flat_idx.reshape(N, 1, D * H * W).expand(N, C, D * H * W)
                corner_val = torch.gather(flat_input, 2, flat_idx)

                term = corner_val.reshape(N, C, D, H, W) * weight.reshape(N, 1, D, H, W)
                out_val = term if out_val is None else out_val + term

    return out_val


if __name__ == "__main__":
    # Demo: compare the hand-written sampler against the built-in one.
    data = torch.rand(1, 1, 256, 104, 80)
    grid = torch.rand(1, 256, 104, 80, 3)

    # grid_sample_3d clamps out-of-range neighbour indices to the edge of the
    # volume, which corresponds to padding_mode="border". The built-in default
    # ("zeros") would disagree wherever a sample's neighbourhood leaves the
    # volume, so request border padding for the reference.
    ret = F.grid_sample(data, grid, padding_mode="border", align_corners=False).squeeze(1)
    print(ret)
    ret2 = grid_sample_3d(data, grid, False).squeeze(1)
    print(ret2)
    # A single number is easier to check than two printed tensors.
    print("max abs diff:", (ret - ret2).abs().max().item())

Acknowledgement: this is adapted from the discussion at
https://github.com/pytorch/pytorch/issues/34704