When exporting a PyTorch model to ONNX, `grid_sample` has no corresponding ONNX operator, so the export fails. The workarounds available online all assume four-dimensional (4D) input, but my input is five-dimensional (5D). I found some existing code and adapted it for 5D input. The code does not try to be fast; it only aims to make the export succeed.
import torch
from torch import nn
import torch.nn.functional as F


def grid_sample_3d(input, grid, align_corners, padding_mode="border"):
    """Trilinearly sample a 5-D ``input`` at the locations in ``grid``.

    Re-implementation of ``F.grid_sample(input, grid, mode="bilinear", ...)``
    for 5-D tensors using only element-wise arithmetic, ``floor``, ``clamp``
    and ``gather``. It is intended as a replacement for ONNX export, where no
    native 3-D grid_sample operator is available.

    Args:
        input: volume of shape (N, C, ID, IH, IW).
        grid: sampling locations of shape (N, D, H, W, 3). The last axis holds
            normalised (x, y, z) coordinates. -1/+1 are the extremes of the
            IW, IH and ID axes respectively, as in ``F.grid_sample``.
        align_corners: same meaning as in ``F.grid_sample``.
        padding_mode: how samples outside the volume are handled.
            ``"border"`` (default, and the behaviour this function has always
            had) uses the nearest edge voxel, like ``F.grid_sample(...,
            padding_mode="border")``. ``"zeros"`` treats voxels outside the
            volume as 0, like ``F.grid_sample``'s own default.

    Returns:
        Tensor of shape (N, C, D, H, W).

    Raises:
        ValueError: if ``padding_mode`` is neither "border" nor "zeros".
    """
    if padding_mode not in ("border", "zeros"):
        raise ValueError(f"unsupported padding_mode: {padding_mode!r}")

    N, C, ID, IH, IW = input.shape
    _, D, H, W, _ = grid.shape

    ix = grid[..., 0]
    iy = grid[..., 1]
    iz = grid[..., 2]

    # Un-normalise [-1, 1] grid coordinates to voxel-space coordinates,
    # using the same formulas as F.grid_sample.
    if align_corners:
        ix = ((ix + 1) / 2) * (IW - 1)
        iy = ((iy + 1) / 2) * (IH - 1)
        iz = ((iz + 1) / 2) * (ID - 1)
    else:
        ix = ((ix + 1) * IW - 1) / 2
        iy = ((iy + 1) * IH - 1) / 2
        iz = ((iz + 1) * ID - 1) / 2

    # Lower corner of the enclosing voxel cell; it carries no gradient.
    with torch.no_grad():
        x0 = torch.floor(ix)
        y0 = torch.floor(iy)
        z0 = torch.floor(iz)

    # For each axis build [(index, weight) for lower corner, upper corner].
    # The weight of one corner is the distance from the sample to the other
    # corner, i.e. standard linear interpolation.
    per_axis = []
    for coord, lo, size in ((ix, x0, IW), (iy, y0, IH), (iz, z0, ID)):
        hi = lo + 1
        corners = []
        for corner, weight in ((lo, hi - coord), (hi, coord - lo)):
            if padding_mode == "zeros":
                # Out-of-volume corners contribute nothing.
                inside = (corner >= 0) & (corner <= size - 1)
                weight = weight * inside.to(weight.dtype)
            # Clamping keeps gather in range. With both corners clamped to
            # the same edge voxel this reproduces "border" padding exactly.
            # clamp() follows the input's device, unlike hand-built constants.
            idx = torch.clamp(corner, 0, size - 1).long()
            corners.append((idx, weight))
        per_axis.append(corners)
    x_corners, y_corners, z_corners = per_axis

    flat = input.reshape(N, C, ID * IH * IW)  # reshape: input may be non-contiguous
    out = None
    # Sum the 8 corner contributions (z outer, then y, then x).
    for z_idx, z_w in z_corners:
        for y_idx, y_w in y_corners:
            for x_idx, x_w in x_corners:
                # Linear index computed in int64. In float32 it would be
                # inexact once the volume exceeds 2**24 voxels.
                lin = (z_idx * IH + y_idx) * IW + x_idx
                lin = lin.reshape(N, 1, D * H * W).expand(N, C, D * H * W)
                vals = torch.gather(flat, 2, lin).reshape(N, C, D, H, W)
                w = (x_w * y_w * z_w).reshape(N, 1, D, H, W)
                term = vals * w
                out = term if out is None else out + term
    return out


if __name__ == "__main__":
    data = torch.rand(1, 1, 256, 104, 80)
    grid = torch.rand(1, 256, 104, 80, 3)
    # grid_sample_3d defaults to border padding, so compare against the
    # matching F.grid_sample configuration (its own default is "zeros").
    ret = F.grid_sample(data, grid, align_corners=False,
                        padding_mode="border").squeeze(1)
    print(ret)
    ret2 = grid_sample_3d(data, grid, False).squeeze(1)
    print(ret2)
    print("border max abs diff:", (ret - ret2).abs().max().item())

    ret_z = F.grid_sample(data, grid, align_corners=False,
                          padding_mode="zeros").squeeze(1)
    ret2_z = grid_sample_3d(data, grid, False, padding_mode="zeros").squeeze(1)
    print("zeros max abs diff:", (ret_z - ret2_z).abs().max().item())
Thanks in advance for any help.
https://github.com/pytorch/pytorch/issues/34704