Model compilation fails for DLA with AssertionError actualOutputDims == expectedOutputDims failed

Hi,

I have a minimal dummy model that simply concatenates the output of a transposed convolution with another tensor. Both tensors have compatible dimensions, and the concatenation happens on the channel dimension (dim = 1; NCHW ordering).

However, when compiling the model for DLA with trtexec (TensorRT 10.3, JetPack 6.1), the conversion emits the following warning:

[W] [TRT] Failed to add layer: Assertion actualOutputDims == expectedOutputDims failed.

After the warning, the model still gets compiled, but during inference it never runs on the DLA; instead, it runs entirely on the GPU. I’ve attached a dummy version of the ONNX model and the trtexec log.

I’ve also seen that the model defaults “fully” to the GPU even when other parts of it could run on the DLA without problems (so, if you add more layers before or after the concatenation, the whole model is still assigned to the GPU). I wanted to know whether there is something wrong with the ONNX export itself, or whether this is an internal error in TensorRT 10.3 that is fixed in a later version. As far as I know, DLA supports concatenation over the channel axis, so I’m not sure what the source of the problem is.

Note: I know it defaults entirely to the GPU because tegrastats reports no DLA usage when running inference (both with trtexec and with a custom script based on the TensorRT Python API). This happens for the attached model as well as for any other model that contains the same layers plus additional ones (for example, if we add convolutional layers at the beginning, tegrastats still reports no DLA usage).
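
For reference, this is roughly how I watch for DLA activity: a quick sketch that just filters the tegrastats output, assuming tegrastats is on the PATH and that its samples mention the DLA cores when they are active.

import subprocess

# Quick sketch: stream tegrastats and print only the samples that mention
# the DLA cores (the exact output format varies across JetPack releases).
proc = subprocess.Popen(['tegrastats'], stdout=subprocess.PIPE, text=True)
try:
    for line in proc.stdout:
        if 'DLA' in line.upper():
            print(line.strip())
finally:
    proc.terminate()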

Thanks a lot for your assistance.
Regards,

model.zip (6.9 MB)

Hi,

It looks like you built the engine with the default fp32 precision.
DLA only supports fp16 and int8 precision, so please run trtexec with --fp16 or --int8.

Thanks.

Hi again,

Thanks for the response! I’ve also tried the fp16 and int8 precision flags, but the same error occurs in both cases. Here is the trtexec command used:

/usr/src/tensorrt/bin/trtexec --onnx=model.onnx --saveEngine=model.engine --useDLACore=0 --allowGPUFallback --fp16 --verbose

I also attach the verbose log from the fp16 compilation (the same assertion warning occurs, and no DLA usage is reported).
trtexec.log (403.1 KB)

I have also checked that the issue persists even when the padding is set to 0 (as required by the list of supported DLA layers in the official documentation). In case it helps, here is the PyTorch code I’m using to reproduce this behaviour in the aforementioned environment (with padding 0):

import torch


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        # Transposed convolution: 1280 -> 96 channels, kernel 4, stride 2, padding 0.
        # For the 8x8 input below: H_out = (8 - 1) * 2 - 2 * 0 + 4 = 18, so the output
        # is (8, 96, 18, 18) and matches the spatial size of the other input.
        self.layer = torch.nn.ConvTranspose2d(1280, 96, 4, 2, 0, bias=False)

    def forward(self, *args) -> torch.Tensor:
        x, y = args[0], args[1]
        # Concatenate along the channel axis (NCHW, dim=1).
        result = torch.cat([x, self.layer(y)], dim=1)
        return result


model = Model()
torch.onnx.export(model, (torch.randn(8, 96, 18, 18), torch.randn(8, 1280, 8, 8)), 'model.onnx')
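
As a side note, I didn’t pass explicit names to the exporter, so the graph inputs get the auto-generated 'onnx::...' identifiers that appear in the onnxruntime script further below. A variant of the export call that pins the names explicitly would look like this (the names 'skip', 'features' and 'concat_out' are just illustrative):

# Sketch: same export, but with explicit (illustrative) tensor names so the
# graph inputs are not the auto-generated 'onnx::...' identifiers.
torch.onnx.export(
    model,
    (torch.randn(8, 96, 18, 18), torch.randn(8, 1280, 8, 8)),
    'model.onnx',
    input_names=['skip', 'features'],
    output_names=['concat_out'],
)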

Compiling the model with zero padding and the --fp16 flag also triggers the warning described earlier ([W] [TRT] Failed to add layer: Assertion actualOutputDims == expectedOutputDims failed.), and the model ends up fully executed on the GPU even though other layers could run on the DLA without issues.

Again, thank you for your assistance, it is very much appreciated.

Hi,

There might be an issue in the export of the model to the ONNX format.
Would you mind verifying the model with onnxruntime to see whether it outputs the expected results?

Thanks.

Hi,

Thanks again for your response, much appreciated. I have run the model with ONNX Runtime using a quick script (below) and it seems to work just fine: the output looks correct for the zero-padding model mentioned in my last post. (If you need this ONNX model, please let me know and I’ll attach it; the ONNX file from my first post contains a deconvolution with padding = 1, which should not run on the DLA given its limitations.)

import numpy as np
import onnxruntime as ort

# Random inputs matching the exported graph: the feature map for the
# transposed convolution and the skip tensor for the concatenation.
input1 = np.random.rand(8, 1280, 8, 8).astype(np.float32)
input2 = np.random.rand(8, 96, 18, 18).astype(np.float32)

sess = ort.InferenceSession('model.onnx')
# The keys are the input names auto-generated by torch.onnx.export.
outputs = sess.run(None, {'onnx::ConvTranspose_1': input1, 'onnx::Concat_0': input2})
print(outputs[0].shape)  # expected: (8, 192, 18, 18)

I hope the above information helps in finding the problem. From what I have seen, the model works fine when running on the GPU; however, the DLA compilation produces this assertion warning, and it looks like no layer of the model will run on the DLA even when it could (which makes me think the warning is the result of a failed DLA compilation).

Thank you again for your time; it is much appreciated.

Hi,

We reduced the model to a single concat layer and reproduced the same error,
so this issue should be related to the concat layer.

import torch


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, *args) -> torch.Tensor:
        x, y = args[0], args[1]
        result = torch.cat([x, y], dim=1)
        return result

model = Model()
torch.onnx.export(model, (torch.randn(8, 32, 8, 8), torch.randn(8, 32, 8, 8)), 'model.onnx')
...
[03/18/2025-05:56:41] [V] [TRT] Graph construction completed in 0.000298648 seconds.
[03/18/2025-05:56:41] [V] [TRT] ---------- Layers Running on DLA ----------
[03/18/2025-05:56:41] [V] [TRT] [DlaLayer] /Concat
[03/18/2025-05:56:41] [V] [TRT] ---------- Layers Running on GPU ----------
[03/18/2025-05:56:41] [V] [TRT] No layer is running on GPU
[03/18/2025-05:56:41] [V] [TRT] After adding DebugOutput nodes: 1 layers
[03/18/2025-05:56:41] [V] [TRT] After Myelin optimization: 1 layers
[03/18/2025-05:56:41] [V] [TRT] 	DLA Memory Pool Sizes: Managed SRAM = 1 MiB,	Local DRAM = 1024 MiB,	Global DRAM = 512 MiB
[03/18/2025-05:56:41] [W] [TRT] Failed to add layer: Assertion actualOutputDims == expectedOutputDims failed. 
[03/18/2025-05:56:41] [W] [TRT] Failed to add layer: Assertion actualOutputDims == expectedOutputDims failed. 
[03/18/2025-05:56:41] [W] [TRT] Failed to add layer: Assertion actualOutputDims == expectedOutputDims failed. 
[03/18/2025-05:56:41] [W] [TRT] Failed to add layer: Assertion actualOutputDims == expectedOutputDims failed. 
[03/18/2025-05:56:41] [W] [TRT] {ForeignNode[/Concat]} cannot be compiled by DLA, falling back to GPU.

We need to check with our internal team for more info and updates.
Thanks.


Hi,

Thanks to you both for taking the time to check and for the update on the topic; it is very much appreciated. I’ll keep an eye on this thread, since a fix would help me deploy models that require this concatenation (e.g. models with skip connections, among others).

Again, thanks for everything!

Hi,

The issue is fixed in our latest TensorRT 10.7 release for the Jetson platform.
Could you give it a try?

Packages

Installation guide

With TensorRT 10.7, we can run the above-mentioned concat-only model on the DLA without the dimension error:

$ TensorRT-10.7.0.23/bin/trtexec --onnx=model.onnx --useDLACore=0 --fp16 --memPoolSize=dlaSRAM:1 --inputIOFormats=fp16:chw16 --outputIOFormats=fp16:chw16 --verbose
...
[03/19/2025-04:38:06] [V] [TRT] ---------- Layers Running on DLA ----------
[03/19/2025-04:38:06] [V] [TRT] [DlaLayer] /Concat
[03/19/2025-04:38:06] [V] [TRT] ---------- Layers Running on GPU ----------
[03/19/2025-04:38:06] [V] [TRT] No layer is running on GPU
[03/19/2025-04:38:06] [V] [TRT] After adding DebugOutput nodes: 1 layers
[03/19/2025-04:38:06] [V] [TRT] After Myelin optimization: 1 layers
[03/19/2025-04:38:06] [V] [TRT]     DLA Memory Pool Sizes: Managed SRAM = 1 MiB,    Local DRAM = 1024 MiB,    Global DRAM = 512 MiB
[03/19/2025-04:38:06] [V] [TRT] Creating DLA tmp dir: /tmp/tensorrt-dla-build-z4XdHM
[03/19/2025-04:38:06] [V] [TRT] Creating DLA tmp dir: /tmp/tensorrt-dla-build-wISPxV
[03/19/2025-04:38:06] [V] [TRT] Creating DLA tmp dir: /tmp/tensorrt-dla-build-z1OfRi
[03/19/2025-04:38:06] [V] [TRT] Creating DLA tmp dir: /tmp/tensorrt-dla-build-kf4nEi
[03/19/2025-04:38:06] [V] [TRT] {ForeignNode[/Concat]} successfully offloaded to DLA.
    Memory consumption: Managed SRAM = 1 MiB,    Local DRAM = 2 MiB,    Global DRAM = 2 MiB
[03/19/2025-04:38:06] [V] [TRT] DLA Memory Consumption Summary:
[03/19/2025-04:38:06] [V] [TRT]     Number of DLA node candidates offloaded : 1 out of 1
[03/19/2025-04:38:06] [V] [TRT]     Total memory required by accepted candidates : Managed SRAM = 1 MiB,    Local DRAM = 2 MiB,    Global DRAM = 2 MiB
[03/19/2025-04:38:06] [V] [TRT] After DLA optimization: 1 layers
...
[03/19/2025-04:38:06] [V] [TRT] Engine Layer Information:
Layer(DLA): {ForeignNode[/Concat]}, Tactic: 0x0000000000000003, onnx::Concat_0 (Half[8,32:16,8,8]), onnx::Concat_1 (Half[8,32:16,8,8]) -> 2 (Half[8,64:16,8,8])
[03/19/2025-04:38:06] [I] [TRT] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 0 MiB
[03/19/2025-04:38:06] [V] [TRT] Adding 1 engine(s) to plan file.
[03/19/2025-04:38:06] [V] [TRT] Adding 1 engine weights(s) to plan file.
[03/19/2025-04:38:06] [I] Engine built in 0.114208 sec.
...
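
For completeness, a minimal sketch of loading such a DLA engine with the TensorRT Python API (assuming the plan file was saved with --saveEngine=model.engine; the DLA core must be selected on the runtime before deserializing):

import tensorrt as trt

# Minimal sketch: deserialize a DLA engine built and saved by trtexec
# (assumes --saveEngine=model.engine was passed at build time).
logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
runtime.DLA_core = 0  # select the DLA core before deserialization

with open('model.engine', 'rb') as f:
    engine = runtime.deserialize_cuda_engine(f.read())

context = engine.create_execution_context()
print(engine.num_io_tensors)  # I/O tensors use the fp16:chw16 formats above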

Thanks.
