Simple 2 layer U-Net breaks TensorRT conversion

Description

I built a one-layer U-Net in PyTorch using only nn.Conv2d(), nn.ConvTranspose2d(), F.relu() and torch.cat(). Exporting it to ONNX with opset version 7 and then building a TensorRT engine with trtexec on a Jetson Xavier with JetPack 4.5 works fine! But when I add a second layer to the U-Net, using exactly the same layer types as before, trtexec reports an error.

Environment

Platform: Jetson XAVIER Dev Kit
Jetpack Version: 4.5

Relevant Files

U-Net 1-layer, which works fine

self.conv1 = nn.Conv2d(4, 80, (3, 3), 2, 1, bias=True)
self.conv10 = nn.Conv2d(80 + 4, 1, (3, 3), 1, 1, bias=True)
self.conv_transpose10 = nn.ConvTranspose2d(80, 80, kernel_size=2, stride=2)

x1 = F.relu(self.conv1(x))
xup = self.conv_transpose10(x1, output_size=x.size()[2:])
d = F.relu(self.conv10(torch.cat((xup, x), 1)))
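Passing output_size=x.size()[2:] to the transposed convolution only works because the stride-2 Conv2d and the stride-2 ConvTranspose2d round-trip the spatial size, so xup and x line up for the concatenation. Below is a small pure-Python sketch using the output-size formulas from the PyTorch Conv2d/ConvTranspose2d documentation; the helper names are hypothetical and the input sizes 64 and 63 are just examples.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length of nn.Conv2d along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0,
                         output_padding: int = 0, dilation: int = 1) -> int:
    """Output length of nn.ConvTranspose2d along one spatial axis (no output_size given)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


# 1-layer U-Net: conv1 uses (k=3, s=2, p=1), conv_transpose10 uses (k=2, s=2, p=0).
for h in (64, 63):
    down = conv2d_out(h, kernel=3, stride=2, padding=1)
    up = conv_transpose2d_out(down, kernel=2, stride=2)
    print(h, down, up, "OK" if up == h else "mismatch with x")
# 64 32 64 OK
# 63 32 64 mismatch with x
```

For an odd input size the upsampled tensor is one pixel larger than x, and PyTorch rejects output_size=63 for that layer, so these snippets assume even spatial sizes.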

U-Net 2-layers, fails

self.conv1 = nn.Conv2d(4, 80, (3, 3), 2, 1, bias=True)
self.conv2 = nn.Conv2d(80, 80, (3, 3), 2, 1, bias=True)
self.conv9 = nn.Conv2d(80 + 80, 32, (3, 3), 1, 1, bias=True)
self.conv10 = nn.Conv2d(32 + 4, 1, (3, 3), 1, 1, bias=True)
self.conv_transpose9 = nn.ConvTranspose2d(80, 80, kernel_size=2, stride=2)
self.conv_transpose10 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2)

x1 = F.relu(self.conv1(x))
x2 = F.relu(self.conv2(x1))
xup = self.conv_transpose9(x2, output_size=x1.size()[2:])
x9 = F.relu(self.conv9(torch.cat((xup, x1), 1)))
xup = self.conv_transpose10(x9, output_size=x.size()[2:])
d = F.relu(self.conv10(torch.cat((xup, x), 1)))
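For anyone trying to reproduce this, here is a minimal self-contained sketch that wraps the 2-layer snippet in an nn.Module and runs a forward pass. The class name UNet2Layer and the 1x4x64x64 input are assumptions for illustration; the attached Unet2layer.py may differ. With two stride-2 levels, height and width should be multiples of 4 so both concatenations line up.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet2Layer(nn.Module):
    """Hypothetical wrapper around the 2-layer snippet from the post."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 80, (3, 3), 2, 1, bias=True)        # H -> H/2
        self.conv2 = nn.Conv2d(80, 80, (3, 3), 2, 1, bias=True)       # H/2 -> H/4
        self.conv9 = nn.Conv2d(80 + 80, 32, (3, 3), 1, 1, bias=True)  # after inner concat
        self.conv10 = nn.Conv2d(32 + 4, 1, (3, 3), 1, 1, bias=True)   # after outer concat
        self.conv_transpose9 = nn.ConvTranspose2d(80, 80, kernel_size=2, stride=2)   # H/4 -> H/2
        self.conv_transpose10 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2)  # H/2 -> H

    def forward(self, x):
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x1))
        xup = self.conv_transpose9(x2, output_size=x1.size()[2:])
        x9 = F.relu(self.conv9(torch.cat((xup, x1), 1)))    # inner skip connection
        xup = self.conv_transpose10(x9, output_size=x.size()[2:])
        return F.relu(self.conv10(torch.cat((xup, x), 1)))  # outer skip uses the raw input x


model = UNet2Layer().eval()
with torch.no_grad():
    d = model(torch.randn(1, 4, 64, 64))  # assumed input: batch 1, 4 channels, 64x64
print(d.shape)  # torch.Size([1, 1, 64, 64])
```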

Steps To Reproduce

Export the PyTorch model to ONNX (opset version 7).

Then, on the Jetson Xavier with JetPack 4.5:

/usr/src/tensorrt/bin/trtexec --buildOnly --fp16 --useDLACore=1 --allowGPUFallback=false --workspace=4096 --explicitBatch --onnx=model.onnx --saveEngine=model_engine.trt


Input filename: model.onnx
ONNX IR version: 0.0.4
Opset version: 7
Producer name: pytorch
Producer version: 1.3
Domain:
Model version: 0
Doc string:

[09/15/2021-15:07:45] [I] [TRT]
[09/15/2021-15:07:45] [I] [TRT] --------------- Layers running on DLA:
[09/15/2021-15:07:45] [I] [TRT] {(Unnamed Layer* 0) [Convolution],(Unnamed Layer* 1) [Activation],(Unnamed Layer* 2) [Convolution],(Unnamed Layer* 3) [Activation],(Unnamed Layer* 4) [Deconvolution],(Unnamed Layer* 5) [Concatenation],(Unnamed Layer* 6) [Convolution],(Unnamed Layer* 7) [Activation],(Unnamed Layer* 8) [Deconvolution]}, {(Unnamed Layer* 10) [Convolution],(Unnamed Layer* 11) [Activation]},
[09/15/2021-15:07:45] [I] [TRT] --------------- Layers running on GPU:
[09/15/2021-15:07:45] [I] [TRT] 21 copy, input copy,
[09/15/2021-15:07:48] [W] [TRT] DLA Node compilation Failed.
[09/15/2021-15:07:48] [W] [TRT] DLA Node compilation Failed.
[09/15/2021-15:07:48] [E] [TRT] Try increasing the workspace size with IBuilderConfig::setMaxWorkspaceSize() if using IBuilder::buildEngineWithConfig, or IBuilder::setMaxWorkspaceSize() if using IBuilder::buildCudaEngine.
[09/15/2021-15:07:48] [E] [TRT] …/builder/tacticOptimizer.cpp (1715) - TRTInternal Error in computeCosts: 0 (Could not find any implementation for node {(Unnamed Layer* 0) [Convolution],(Unnamed Layer* 1) [Activation],(Unnamed Layer* 2) [Convolution],(Unnamed Layer* 3) [Activation],(Unnamed Layer* 4) [Deconvolution],(Unnamed Layer* 5) [Concatenation],(Unnamed Layer* 6) [Convolution],(Unnamed Layer* 7) [Activation],(Unnamed Layer* 8) [Deconvolution]}.)
[09/15/2021-15:07:48] [E] [TRT] …/builder/tacticOptimizer.cpp (1715) - TRTInternal Error in computeCosts: 0 (Could not find any implementation for node {(Unnamed Layer* 0) [Convolution],(Unnamed Layer* 1) [Activation],(Unnamed Layer* 2) [Convolution],(Unnamed Layer* 3) [Activation],(Unnamed Layer* 4) [Deconvolution],(Unnamed Layer* 5) [Concatenation],(Unnamed Layer* 6) [Convolution],(Unnamed Layer* 7) [Activation],(Unnamed Layer* 8) [Deconvolution]}.)
[09/15/2021-15:07:48] [E] Engine creation failed
[09/15/2021-15:07:48] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec # /usr/src/tensorrt/bin/trtexec --buildOnly --fp16 --useDLACore=1 --allowGPUFallback=false --workspace=1024 --explicitBatch --onnx=model.onnx --saveEngine=model_engine.trt

Hi @holger.lange,

Could you please share the ONNX model that reproduces the issue, so we can try it on our end for better debugging?

Thank you.

Thank you! Attached are the model.onnx file and the original PyTorch model, Unet2layer.py.
model.onnx (536.6 KB)
Unet2layer.py (1.6 KB)

When I remove the torch.cat() in the lower layer, the build succeeds. So is there something TensorRT doesn't like about torch.cat() in that layer?

    self.conv1 = nn.Conv2d(4, 80, (3, 3), 2, 1, bias=True)
    self.conv2 = nn.Conv2d(80, 80, (3, 3), 2, 1, bias=True)
    self.conv9 = nn.Conv2d(80, 32, (3, 3), 1, 1, bias=True)
    self.conv10 = nn.Conv2d(32 + 4, 1, (3, 3), 1, 1, bias=True)
    self.conv_transpose9 = nn.ConvTranspose2d(80, 80, kernel_size=2, stride=2)
    self.conv_transpose10 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2)

    x1 = F.relu(self.conv1(x))
    x2 = F.relu(self.conv2(x1))
    xup = self.conv_transpose9(x2, output_size=x1.size()[2:])
    x9 = F.relu(self.conv9(xup))
    xup = self.conv_transpose10(x9, output_size=x.size()[2:])
    d = F.relu(self.conv10(torch.cat((xup, x), 1)))

[09/14/2021-18:17:47] [I] [TRT]
[09/14/2021-18:17:47] [I] [TRT] --------------- Layers running on DLA:
[09/14/2021-18:17:47] [I] [TRT] {(Unnamed Layer* 0) [Convolution],(Unnamed Layer* 1) [Activation],(Unnamed Layer* 2) [Convolution],(Unnamed Layer* 3) [Activation],(Unnamed Layer* 4) [Deconvolution],(Unnamed Layer* 5) [Convolution],(Unnamed Layer* 6) [Activation],(Unnamed Layer* 7) [Deconvolution]}, {(Unnamed Layer* 9) [Convolution],(Unnamed Layer* 10) [Activation]},
[09/14/2021-18:17:47] [I] [TRT] --------------- Layers running on GPU:
[09/14/2021-18:17:47] [I] [TRT] 20 copy, input copy,
[09/14/2021-18:17:59] [I] [TRT] Detected 1 inputs and 1 output network tensors.
&&&& PASSED TensorRT.trtexec # /usr/src/tensorrt/bin/trtexec --buildOnly --fp16 --useDLACore=1 --allowGPUFallback=false --explicitBatch --workspace=16384 --onnx=model.onnx --saveEngine=model_engine.trt

@spolisetty torch.cat() is supported, right? Is there a different way to code it?

Hi,

torch.cat() is supported, so it shouldn't be a problem. We are looking into this issue; please allow us some time to get back to you.

Thank you.

I appreciate it! Thanks a lot!

Hi,

Using the model you shared, we are unable to reproduce the issue on the latest TensorRT version. If the model you shared is the version with torch.cat() removed, please share the ONNX model that reproduces the issue so we can try again.

If not, we recommend trying the latest TensorRT version. You can also use the TensorRT NGC container: TensorRT | NVIDIA NGC

Thank you.

I have to use JetPack 4.5 right now, and that is where I see the problem.

I am going to try JetPack 4.6 to see if the problem goes away. Thanks for looking into it.


I upgraded to JetPack 4.6, which made this problem go away, but now I get a different error, even with the 1-layer network that worked before:

Module_id 33 Severity 2 : NVMEDIA_DLA 684
Module_id 33 Severity 2 : Failed to bind input tensor. err : 0x00000b
Module_id 33 Severity 2 : NVMEDIA_DLA 2866
Module_id 33 Severity 2 : Failed to bind input tensor args. status: 0x000007

More details here:

[09/22/2021-00:42:26] [I] === Device Information ===
[09/22/2021-00:42:26] [I] Selected Device: Xavier
[09/22/2021-00:42:26] [I] Compute Capability: 7.2
[09/22/2021-00:42:26] [I] SMs: 8
[09/22/2021-00:42:26] [I] Compute Clock Rate: 1.377 GHz
[09/22/2021-00:42:26] [I] Device Global Memory: 31928 MiB
[09/22/2021-00:42:26] [I] Shared Memory per SM: 96 KiB
[09/22/2021-00:42:26] [I] Memory Bus Width: 256 bits (ECC disabled)
[09/22/2021-00:42:26] [I] Memory Clock Rate: 1.377 GHz
[09/22/2021-00:42:26] [I]
[09/22/2021-00:42:26] [I] TensorRT version: 8001
[09/22/2021-00:42:27] [I] [TRT] [MemUsageChange] Init CUDA: CPU +353, GPU +0, now: CPU 371, GPU 3959 (MiB)
[09/22/2021-00:42:27] [I] Start parsing network model
[09/22/2021-00:42:27] [I] [TRT] ----------------------------------------------------------------
[09/22/2021-00:42:27] [I] [TRT] Input filename: model.onnx
[09/22/2021-00:42:27] [I] [TRT] ONNX IR version: 0.0.4
[09/22/2021-00:42:27] [I] [TRT] Opset version: 7
[09/22/2021-00:42:27] [I] [TRT] Producer name: pytorch
[09/22/2021-00:42:27] [I] [TRT] Producer version: 1.3
[09/22/2021-00:42:27] [I] [TRT] Domain:
[09/22/2021-00:42:27] [I] [TRT] Model version: 0
[09/22/2021-00:42:27] [I] [TRT] Doc string:
[09/22/2021-00:42:27] [I] [TRT] ----------------------------------------------------------------
[09/22/2021-00:42:27] [I] Finish parsing network model
[09/22/2021-00:42:27] [I] [TRT] [MemUsageChange] Init CUDA: CPU +0, GPU +0, now: CPU 372, GPU 3959 (MiB)
[09/22/2021-00:42:27] [I] [TRT] [MemUsageSnapshot] Builder begin: CPU 372 MiB, GPU 3959 MiB
[09/22/2021-00:42:27] [I] [TRT] ---------- Layers Running on DLA ----------
[09/22/2021-00:42:27] [I] [TRT] [DlaLayer] {ForeignNode[node_of_7…node_of_output]}
[09/22/2021-00:42:27] [I] [TRT] ---------- Layers Running on GPU ----------
[09/22/2021-00:42:27] [I] [TRT] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +225, GPU +226, now: CPU 599, GPU 4194 (MiB)
[09/22/2021-00:42:29] [I] [TRT] [MemUsageChange] Init cuDNN: CPU +307, GPU +312, now: CPU 906, GPU 4506 (MiB)
[09/22/2021-00:42:29] [W] [TRT] Detected invalid timing cache, setup a local cache instead
Module_id 33 Severity 2 : NVMEDIA_DLA 684
Module_id 33 Severity 2 : Failed to bind input tensor. err : 0x00000b
Module_id 33 Severity 2 : NVMEDIA_DLA 2866
Module_id 33 Severity 2 : Failed to bind input tensor args. status: 0x000007
[09/22/2021-00:42:35] [I] [TRT] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 906, GPU 4508 (MiB)
[09/22/2021-00:42:35] [E] Error[1]: [nvdlaUtils.cpp::submit::198] Error Code 1: DLA (Failure to submit program to DLA engine.)
[09/22/2021-00:42:35] [E] Error[2]: [builder.cpp::buildSerializedNetwork::417] Error Code 2: Internal Error (Assertion enginePtr != nullptr failed.)

The problem seems to be the torch.cat() that concatenates the network input tensor x in the last layer. Without this concatenation, everything works fine. The line in question is:

d = F.relu(self.conv10(torch.cat((xup, x), 1)))

Hi @holger.lange,

We recommend posting your question on the Jetson forum to get better help.

Thank you.

Because using the input x directly in a skip connection causes a problem on the DLA, I ran x through a dummy 1x1 conv2d layer with fixed weights, and now it works with JetPack 4.6!

self.conv0 = nn.Conv2d(4, 4, (1, 1), 1, 0, bias=False)
self.conv0.weight = torch.nn.Parameter(torch.ones_like(self.conv0.weight), requires_grad=False)

x0 = F.relu(self.conv0(x))

d = F.relu(self.conv10(torch.cat((xup, x0), 1)))
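One thing worth checking in this workaround: a 1x1 convolution whose weights are all ones does not pass x through unchanged. Each of the 4 output channels becomes the sum of all 4 input channels, and the following F.relu also zeroes negative values. If the intent is a pure pass-through that still gives the DLA a non-input tensor to concatenate, identity weights are one option. The sketch below compares the two on random data; whether the identity variant also avoids the DLA error on JetPack 4.6 is an untested assumption.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 4, 8, 8)  # stand-in for the 4-channel network input (assumed shape)

# Variant from the post: all-ones 1x1 weights, so every output channel = sum of all input channels.
ones_conv = nn.Conv2d(4, 4, (1, 1), 1, 0, bias=False)
ones_conv.weight = nn.Parameter(torch.ones_like(ones_conv.weight), requires_grad=False)

# Alternative (assumption, not from the post): identity weights, so output channel c = input channel c.
ident_conv = nn.Conv2d(4, 4, (1, 1), 1, 0, bias=False)
ident_conv.weight = nn.Parameter(torch.eye(4).view(4, 4, 1, 1), requires_grad=False)

with torch.no_grad():
    y_ones = ones_conv(x)
    y_ident = ident_conv(x)

channel_sum = x.sum(dim=1, keepdim=True).expand_as(x)
print(torch.allclose(y_ones, channel_sum, atol=1e-5))  # True: the channels get mixed
print(torch.allclose(y_ident, x, atol=1e-6))           # True: x passes through unchanged
print(torch.equal(torch.relu(y_ident), x))             # False: ReLU clips the negative inputs
```

In the module, the identity variant would mean replacing torch.ones_like(self.conv0.weight) with torch.eye(4).view(4, 4, 1, 1), and possibly dropping the F.relu around self.conv0(x) if negative input values matter.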

Thanks a lot for your help!
