[TensorRT] Running a simple ONNX model on Jetson Xavier DLA

Hi,
I have a simple Python script which I am using to run TensorRT inference on an ONNX model on Jetson Xavier (TensorRT 8.4.0 + CUDA 11.4).

I wanted to run this inference purely on DLA, so I disabled GPU fallback.
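The relevant DLA settings in my script are along these lines (a trimmed sketch, not my exact script):

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

def build_dla_engine(onnx_path):
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            return None
    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.FP16)      # DLA requires FP16 or INT8
    config.default_device_type = trt.DeviceType.DLA
    config.DLA_core = 0
    # GPU fallback deliberately left disabled:
    # config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
    return builder.build_serialized_network(network, config)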
I initially tried a ResNet-50 ONNX model, but it failed because some of its layers needed GPU fallback enabled.

So, I decided to write my own model with a couple of conv layers and a couple of fully connected layers:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 3 input image channels, 6 output channels,
        # 5x5 square convolution kernel
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=1)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=1)
        # Max pooling over a (2, 2) window
        self.pool = nn.MaxPool2d(2, 2)
        # LazyLinear infers its input features on the first forward pass
        self.fc1 = nn.LazyLinear(120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.flatten(1)  # this becomes the Flatten/Reshape node in ONNX
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

But even when I run this simple model, it fails to run on the DLA, with this error:

[07/08/2022-16:45:48] [TRT] [E] 9: [standardEngineBuilder.cpp::isValidDLAConfig::1539] Error Code 9: Internal Error (Default DLA is enabled but layer Flatten_6 + reshape_before_Gemm_7 is not supported on DLA and falling back to GPU is not enabled.)
  1. Is it because the graph optimizer fused the layers and the fused layers can't run on the DLA? Is there any way to disable this fusion?
  2. Is there any standard ONNX model in the ONNX model zoo (GitHub - onnx/models: A collection of pre-trained, state-of-the-art models in the ONNX format), or a custom model example, that can run purely on DLA without any GPU fallback?

Hi,
Request you to share the ONNX model and the script, if not shared already, so that we can assist you better.
Alongside, you can try a few things:

  1. Validate your model with the below snippet.

check_model.py

import sys
import onnx

filename = sys.argv[1]  # path to your ONNX model
model = onnx.load(filename)
onnx.checker.check_model(model)
  2. Try running your model with the trtexec command, as in the example below.
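For example, something like this (an illustrative command; replace model.onnx with your model, and note that DLA needs FP16 or INT8 precision, hence --fp16):

$ /usr/src/tensorrt/bin/trtexec --onnx=model.onnx --useDLACore=0 --fp16 --verbose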

In case you are still facing the issue, request you to share the trtexec --verbose log for further debugging.
Thanks!

I have checked my ONNX model.
I have also shared the relevant portion of the script.
Whenever I try to upload the model, I get an error, but I have shared the model layers via code anyway.

Hi,

We are moving this post to the Jetson Xavier NX forum to get better help.

Thank you.

Hi,

Based on your log, the failure is caused by the Flatten layer.
You can find the DLA support matrix in the TensorRT developer guide.

Would you mind sharing the complete source that generates the ONNX model,
so we can check it in our environment directly?

Thanks.

Hi,
Thanks for the reply.

Sure. Attached below is the complete source to generate the ONNX model (dla_neural_net.onnx).
The model is not trained (to save time) but is initialized with Xavier initialization; you can train it if you want to improve accuracy, which wasn't my main purpose.

dla_model_onnx.py (4.1 KB)
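For reference, the export portion of the script is roughly the following (a trimmed sketch, not the attachment verbatim; the 1x3x32x32 input shape is an assumption):

import torch
import torch.nn as nn

def xavier_init(m):
    # Xavier-initialize conv and linear weights (illustrative helper)
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)

net = Net()                        # the Net class defined earlier
dummy = torch.randn(1, 3, 32, 32)
net(dummy)                         # forward once so LazyLinear materializes
net.apply(xavier_init)
torch.onnx.export(net, dummy, "dla_neural_net.onnx", opset_version=11)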

Since you mentioned that the flatten layer is the one not supported by the DLA, can I somehow replace it so that the model can be run purely on the DLA? Or is there any other example of a simple neural net running purely on DLA?

Hello @AastaLLL, could you generate the model from the sources? Do you have any suggestions on how to make it run purely on DLA?

Hi,

Sorry for the late reply.

We can generate the model and also confirm the failure comes from the flatten layer.

[08/04/2022-16:20:55] [I] [TRT] ---------- Layers Running on DLA ----------
[08/04/2022-16:20:55] [I] [TRT] [DlaLayer] {ForeignNode[Conv_0...MaxPool_5]}
[08/04/2022-16:20:55] [I] [TRT] [DlaLayer] {ForeignNode[Relu_8]}
[08/04/2022-16:20:55] [I] [TRT] [DlaLayer] {ForeignNode[Gemm_9]}
[08/04/2022-16:20:55] [I] [TRT] [DlaLayer] {ForeignNode[Relu_10]}
[08/04/2022-16:20:55] [I] [TRT] [DlaLayer] {ForeignNode[Gemm_11]}
[08/04/2022-16:20:55] [I] [TRT] ---------- Layers Running on GPU ----------
[08/04/2022-16:20:55] [I] [TRT] [GpuLayer] SHUFFLE: Flatten_6 + reshape_before_Gemm_7
[08/04/2022-16:20:55] [I] [TRT] [GpuLayer] CONVOLUTION: Gemm_7
[08/04/2022-16:20:55] [I] [TRT] [GpuLayer] SHUFFLE: reshape_after_Gemm_7 + shuffle onnx::Relu_18 0
[08/04/2022-16:20:55] [I] [TRT] [GpuLayer] SHUFFLE: shuffle onnx::Gemm_19 0 + reshape_before_Gemm_9
[08/04/2022-16:20:55] [I] [TRT] [GpuLayer] SHUFFLE: reshape_after_Gemm_9 + shuffle onnx::Relu_20 0
[08/04/2022-16:20:55] [I] [TRT] [GpuLayer] SHUFFLE: shuffle onnx::Gemm_21 0 + reshape_before_Gemm_11
[08/04/2022-16:20:55] [I] [TRT] [GpuLayer] SHUFFLE: reshape_after_Gemm_11

Is it possible to change the model so that it doesn't use reshape/flatten?
Thanks.

Thanks a lot for the reply!
Sorry, I am a newbie, but what would be a substitute for reshape/flatten that can run purely on DLA? I don't have specific requirements for a model, so if there is any other model that can run purely on DLA, you could point me to that. I just want to successfully run a CNN purely on DLA; it could be a standard pre-trained model or any other simple model similar to the one I wrote.

Hi,

Do you want a classifier?

If yes, maybe you can try a model that doesn't have the FC layer.
For example, GoogleNet.
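Alternatively, you could rewrite the classifier head of your small model with 1x1 convolutions after global average pooling, so that no Flatten/Reshape node is generated in the first place. A rough sketch (layer sizes are illustrative, and whether the pooling kernel size satisfies the DLA restrictions depends on your input size):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvOnlyNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=1)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # 1x1 convolutions take the place of the fully connected layers
        self.head1 = nn.Conv2d(16, 120, kernel_size=1)
        self.head2 = nn.Conv2d(120, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # global average pooling collapses H x W without a Reshape node
        x = F.avg_pool2d(x, kernel_size=x.shape[2:])
        x = F.relu(self.head1(x))
        return self.head2(x)       # output shape: (N, num_classes, 1, 1)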

Thanks.

Yes, I was looking for a classifier.
Unfortunately, GoogleNet doesn't work purely on DLA. I tried it by downloading the standard GoogleNet weights from torch (googlenet-1378be20.pth) and then converting the model to ONNX.

The error I get is:

[07/29/2022-08:37:07] [TRT] [E] 9: [standardEngineBuilder.cpp::isValidDLAConfig::1539] Error Code 9: Internal Error (Default DLA is enabled but layer {ForeignNode[Gather_1...Gather_17]} is not supported on DLA and falling back to GPU is not enabled.)

Hi,

Have you set the last FC layer as the output?
Since DLA support for softmax was only added starting with Orin, you will need to mark the FC layer as the output.

For example, we can run the GoogleNet model included in TensorRT this way:

$ cd /usr/src/tensorrt/data/googlenet/
$ /usr/src/tensorrt/bin/trtexec --deploy=googlenet.prototxt --output=loss3/classifier --useDLACore=0
&&&& RUNNING TensorRT.trtexec [TensorRT v8201] # /usr/src/tensorrt/bin/trtexec --deploy=googlenet.prototxt --output=loss3/classifier --useDLACore=0
[08/10/2022-14:43:03] [I] === Model Options ===
[08/10/2022-14:43:03] [I] Format: Caffe
[08/10/2022-14:43:03] [I] Model:
[08/10/2022-14:43:03] [I] Prototxt: googlenet.prototxt
[08/10/2022-14:43:03] [I] Output: loss3/classifier
...
[08/10/2022-14:43:05] [I] Start parsing network model
[08/10/2022-14:43:05] [I] Finish parsing network model
[08/10/2022-14:43:07] [I] [TRT] ---------- Layers Running on DLA ----------
[08/10/2022-14:43:07] [I] [TRT] [DlaLayer] {ForeignNode[conv1/7x7_s2...loss3/classifier]}
[08/10/2022-14:43:07] [I] [TRT] ---------- Layers Running on GPU ----------
[08/10/2022-14:43:08] [I] [TRT] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +226, GPU +201, now: CPU 741, GPU 8478 (MiB)
[08/10/2022-14:43:09] [I] [TRT] [MemUsageChange] Init cuDNN: CPU +307, GPU +310, now: CPU 1048, GPU 8788 (MiB)
[08/10/2022-14:43:09] [I] [TRT] Local timing cache in use. Profiling results in this builder pass will not be stored.
[08/10/2022-14:43:44] [I] [TRT] Detected 1 inputs and 1 output network tensors.
...
&&&& PASSED TensorRT.trtexec [TensorRT v8201] # /usr/src/tensorrt/bin/trtexec --deploy=googlenet.prototxt --output=loss3/classifier --useDLACore=0
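
If you start from an ONNX model rather than the Caffe prototxt, one possible way to mark the FC layer as the output (a sketch; the tensor names below are hypothetical and must match your exported graph) is to cut the graph with onnx.utils.extract_model:

import onnx

# "input" and "fc_out" are placeholder tensor names; inspect your graph
# (for example with Netron) to find the real ones before extracting.
onnx.utils.extract_model(
    "googlenet.onnx", "googlenet_no_softmax.onnx",
    input_names=["input"], output_names=["fc_out"])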

Thanks.
