SwinV2 traced inference (Jetson)

I’m trying to run a JIT-traced Swin V2 model on an NVIDIA Jetson, but I get an error after one inference step on CUDA. There are no problems on CPU.

Steps to reproduce the behavior:
Prepare

import torch
from timm.models.swin_transformer_v2 import swinv2_tiny_window16_256

model = swinv2_tiny_window16_256(pretrained=True)
inputs = torch.randn(1, 3, 256, 256)

jit_file = "swin_v2.pt"
jit_model = torch.jit.trace(model, (inputs,))
# torch.jit.save takes the traced module first, then the destination path
torch.jit.save(jit_model, jit_file)

Test

import torch

jit_file = "swin_v2.pt"
model = torch.jit.load(jit_file, map_location="cuda")
inputs = torch.randn(1, 3, 256, 256).cuda()
with torch.no_grad():
    model(inputs)

Tracing itself works, so I would expect to be able to run inference with the model, but after one call I see this error:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/timm/models/swin_transformer_v2/___torch_mangle_569.py", line 39, in forward
    denom0 = torch.expand_as(_20, input1)
    _21 = torch.transpose(torch.div(input1, denom0), -2, -1)
    attn = torch.matmul(_19, _21)
           ~~~~~~~~~~~~ <--- HERE
    attn0 = torch.mul(attn, CONSTANTS.c7)
    input2 = torch.linear(CONSTANTS.c8, CONSTANTS.c9, CONSTANTS.c10)
    
    RuntimeError: The size of tensor a (3) must match the size of tensor b (32) at non-singleton dimension 1
  • OS: Nvidia Jetson Orin –
  • timm == 0.9.8
  • PyTorch==2.0.0 with CUDA==11.4
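
For anyone reading the message above: "size of tensor a (3) must match the size of tensor b (32) at non-singleton dimension 1" is the kind of error produced when the leading (batch) dimensions of the two matmul operands cannot be broadcast together. Below is a minimal pure-Python sketch of that rule. The function names and the example shapes are hypothetical and chosen only for illustration; they are not taken from the traced model, and this is not PyTorch's actual implementation.

```python
from itertools import zip_longest


def broadcast_batch_dims(a, b):
    """Broadcast two batch shapes, right-aligned; raise ValueError on mismatch."""
    n = max(len(a), len(b))
    out = []
    # Walk from the trailing dimension backwards, padding the shorter shape with 1.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and da != 1 and db != 1:
            dim = n - 1 - len(out)  # index of the offending dimension
            raise ValueError(
                f"size of a ({da}) must match size of b ({db}) "
                f"at non-singleton dimension {dim}"
            )
        out.append(db if da == 1 else da)
    return tuple(reversed(out))


def matmul_shape(a, b):
    """Result shape of a batched matrix product for operands with >= 2 dims."""
    if len(a) < 2 or len(b) < 2:
        raise ValueError("this sketch only handles operands with at least 2 dims")
    if a[-1] != b[-2]:
        raise ValueError(f"inner dimensions differ: {a[-1]} vs {b[-2]}")
    # Everything before the last two dims is treated as batch and broadcast.
    return broadcast_batch_dims(a[:-2], b[:-2]) + (a[-2], b[-1])


if __name__ == "__main__":
    # Hypothetical shapes: both operands agree on dimension 1.
    print(matmul_shape((16, 3, 256, 32), (16, 3, 32, 256)))
    # Hypothetical shapes: 3 vs 32 at dimension 1 cannot be broadcast.
    try:
        matmul_shape((16, 3, 256, 32), (16, 32, 32, 256))
    except ValueError as err:
        print("error:", err)
```

Printing the shapes of the two operands just before the failing matmul, on both CPU and CUDA, would show which of them ends up with an unexpected layout.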

Does anyone have idea what might be the problem?

Hello,

Welcome to the NVIDIA Developer forums! I am moving your topic to the Jetson Orin category for better visibility.

Hi,

The error looks like it is triggered from user space.
Have you tried the same script on other platforms like x86?

Thanks.

Hi @AastaLLL

Yes, it works on x64.
I was only able to trace on the x64 platform, but I used the same PyTorch version as in the nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3 container.

I also tried to do the trace on the Jetson Orin itself, but tracing fails with an error there.

Hi,

Which PyTorch version do you use on the desktop?
Is it v2.0 as well?

Thanks.

Hi,

I used 2.0.0 (same result, but the CUDA version is newer), and I also downgraded PyTorch to 1.12.1 so that the CUDA version was not newer.
So:
x64: pytorch==1.12.1+cu113
("https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp310-cp310-linux_x86_64.whl#sha256=be682ef94e37cd3f0768b8ce6106705410189df2c365d65d7bc1bebb302d84cd")

Jetson Orin: pytorch: nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3
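
To compare the two environments side by side, a small report like the following could be run on both the x64 machine and the Jetson. This is only a sketch: `describe_environment` is a hypothetical helper name, and it assumes the packages of interest are `torch` and `timm`. It uses `torch.version.cuda` and `torch.cuda.is_available()` when PyTorch can be imported, and falls back to `None` otherwise.

```python
import platform
import sys
from importlib import metadata


def describe_environment(packages=("torch", "timm")):
    """Collect interpreter, architecture, package and CUDA build information."""
    info = {
        "python": sys.version.split()[0],
        "machine": platform.machine(),  # e.g. x86_64 or aarch64
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = None  # package is not installed here
    try:
        import torch
    except ImportError:
        # PyTorch is missing, so no CUDA information can be reported.
        info["cuda_build"] = None
        info["cuda_available"] = None
    else:
        info["cuda_build"] = torch.version.cuda  # CUDA version PyTorch was built with
        info["cuda_available"] = torch.cuda.is_available()
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

Posting the output from both machines makes it easier to see whether the PyTorch build, the timm version, or the CUDA version differs between the machine that produced the trace and the one that loads it.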

There has been no update from you for a while, so we assume this is no longer an issue.
Hence we are closing this topic. If you need further support, please open a new one.
Thanks

Hi,

I just want to confirm one thing first.
When you tried 2.0.0 on x86, did it work, or did it hit the same error as on Orin?

Thanks