How can I customize matrix multiplication on DLA

I know that the DLA on Orin has strong INT8 compute capability, so I want to offload some of the INT8 matrix multiplications to the DLA, but I found that the DLA does not support matrix multiplication and TensorRT recommends using convolution operations instead. I have tried many times but still cannot get it to work. Can you give me an example of how to achieve this?

Hi,

What kind of matrix multiplication do you want to use?
Is it possible to rewrite it as a TensorRT layer?

DLA is mainly used by TensorRT but can be run standalone with the cuDLA library.

https://docs.nvidia.com/cuda/cudla-api/index.html

Thanks.

Thanks for the reply.

I am a novice with DLA and TensorRT. Can you give me a specific example of using the DLA for a simple matrix multiplication or matrix-vector multiplication (rewritten as a TensorRT layer, as you suggested)?

Also, I tried to understand cuDLA, but it seems that it can only submit pre-built tasks to the DLA; it does not let me write custom functions for the DLA the way CUDA lets me write custom kernels for the GPU.

I would really appreciate some help!

Hi,

cuDLA is a library that allows users to deploy a DLA plan without using TensorRT, but the DLA plan still needs to be compiled with TensorRT.
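For example, an ONNX model can be compiled into a DLA standalone loadable with trtexec and then deployed through cuDLA. A minimal sketch (model.onnx is a placeholder; the flags are taken from TensorRT 8.6's trtexec, please confirm with trtexec --help on your release):

/usr/src/tensorrt/bin/trtexec --onnx=model.onnx --useDLACore=0 --fp16 --buildDLAStandalone --saveEngine=model.loadable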

The simplest way is to define a PyTorch model, export it to ONNX format, and deploy it with trtexec.

Here is a toy version for your reference:

  1. Install PyTorch and ONNX; this step can be done either on the Jetson or on a desktop machine
  2. Define a matrix multiplication model
    For example: C (128x128) = A (128x1024) x B (1024x128)
import torch
import torch.onnx

# Toy model that simply multiplies its two inputs
class TinyModel(torch.nn.Module):

    def __init__(self):
        super(TinyModel, self).__init__()

    def forward(self, x1, x2):
        x = torch.matmul(x1, x2)
        return x

tinymodel = TinyModel()
tinymodel.eval()

# Dummy inputs: A (128x1024) and B (1024x128)
x1 = torch.randn(128, 1024, requires_grad=True)
x2 = torch.randn(1024, 128, requires_grad=True)

# Export to ONNX so trtexec can consume it
torch.onnx.export(tinymodel, (x1, x2), "matmul.onnx", export_params=True,
                  opset_version=10, input_names=['matrixA', 'matrixB'], output_names=['output'])
  3. Copy matmul.onnx to the Jetson device
  4. Deploy it via trtexec with DLA (see the example command below)
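For example (using the default trtexec path on JetPack):

/usr/src/tensorrt/bin/trtexec --onnx=matmul.onnx --useDLACore=0 --fp16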

Thanks.

Thanks for the reply.

When I tried your method, I got an error when building the engine (see the output below): DLA does not support matrix multiplication.

&&&& RUNNING TensorRT.trtexec [TensorRT v8602] # /usr/src/tensorrt/bin/trtexec --onnx=torch_matmul.onnx --useDLACore=0 --fp16
[08/14/2024-14:08:46] [I] === Model Options ===
[08/14/2024-14:08:46] [I] Format: ONNX
[08/14/2024-14:08:46] [I] Model: torch_matmul.onnx
[08/14/2024-14:08:46] [I] Output:
[08/14/2024-14:08:46] [I] === Build Options ===
[08/14/2024-14:08:46] [I] Max batch: explicit batch
[08/14/2024-14:08:46] [I] Memory Pools: workspace: default, dlaSRAM: default, dlaLocalDRAM: default, dlaGlobalDRAM: default
[08/14/2024-14:08:46] [I] minTiming: 1
[08/14/2024-14:08:46] [I] avgTiming: 8
[08/14/2024-14:08:46] [I] Precision: FP32+FP16
[08/14/2024-14:08:46] [I] LayerPrecisions:
[08/14/2024-14:08:46] [I] Layer Device Types:
[08/14/2024-14:08:46] [I] Calibration:
[08/14/2024-14:08:46] [I] Refit: Disabled
[08/14/2024-14:08:46] [I] Version Compatible: Disabled
[08/14/2024-14:08:46] [I] ONNX Native InstanceNorm: Disabled
[08/14/2024-14:08:46] [I] TensorRT runtime: full
[08/14/2024-14:08:46] [I] Lean DLL Path:
[08/14/2024-14:08:46] [I] Tempfile Controls: { in_memory: allow, temporary: allow }
[08/14/2024-14:08:46] [I] Exclude Lean Runtime: Disabled
[08/14/2024-14:08:46] [I] Sparsity: Disabled
[08/14/2024-14:08:46] [I] Safe mode: Disabled
[08/14/2024-14:08:46] [I] Build DLA standalone loadable: Disabled
[08/14/2024-14:08:46] [I] Allow GPU fallback for DLA: Disabled
[08/14/2024-14:08:46] [I] DirectIO mode: Disabled
[08/14/2024-14:08:46] [I] Restricted mode: Disabled
[08/14/2024-14:08:46] [I] Skip inference: Disabled
[08/14/2024-14:08:46] [I] Save engine:
[08/14/2024-14:08:46] [I] Load engine:
[08/14/2024-14:08:46] [I] Profiling verbosity: 0
[08/14/2024-14:08:46] [I] Tactic sources: Using default tactic sources
[08/14/2024-14:08:46] [I] timingCacheMode: local
[08/14/2024-14:08:46] [I] timingCacheFile:
[08/14/2024-14:08:46] [I] Heuristic: Disabled
[08/14/2024-14:08:46] [I] Preview Features: Use default preview flags.
[08/14/2024-14:08:46] [I] MaxAuxStreams: -1
[08/14/2024-14:08:46] [I] BuilderOptimizationLevel: -1
[08/14/2024-14:08:46] [I] Input(s)s format: fp32:CHW
[08/14/2024-14:08:46] [I] Output(s)s format: fp32:CHW
[08/14/2024-14:08:46] [I] Input build shapes: model
[08/14/2024-14:08:46] [I] Input calibration shapes: model
[08/14/2024-14:08:46] [I] === System Options ===
[08/14/2024-14:08:46] [I] Device: 0
[08/14/2024-14:08:46] [I] DLACore: 0
[08/14/2024-14:08:46] [I] Plugins:
[08/14/2024-14:08:46] [I] setPluginsToSerialize:
[08/14/2024-14:08:46] [I] dynamicPlugins:
[08/14/2024-14:08:46] [I] ignoreParsedPluginLibs: 0
[08/14/2024-14:08:46] [I]
[08/14/2024-14:08:46] [I] === Inference Options ===
[08/14/2024-14:08:46] [I] Batch: Explicit
[08/14/2024-14:08:46] [I] Input inference shapes: model
[08/14/2024-14:08:46] [I] Iterations: 10
[08/14/2024-14:08:46] [I] Duration: 3s (+ 200ms warm up)
[08/14/2024-14:08:46] [I] Sleep time: 0ms
[08/14/2024-14:08:46] [I] Idle time: 0ms
[08/14/2024-14:08:46] [I] Inference Streams: 1
[08/14/2024-14:08:46] [I] ExposeDMA: Disabled
[08/14/2024-14:08:46] [I] Data transfers: Enabled
[08/14/2024-14:08:46] [I] Spin-wait: Disabled
[08/14/2024-14:08:46] [I] Multithreading: Disabled
[08/14/2024-14:08:46] [I] CUDA Graph: Disabled
[08/14/2024-14:08:46] [I] Separate profiling: Disabled
[08/14/2024-14:08:46] [I] Time Deserialize: Disabled
[08/14/2024-14:08:46] [I] Time Refit: Disabled
[08/14/2024-14:08:46] [I] NVTX verbosity: 0
[08/14/2024-14:08:46] [I] Persistent Cache Ratio: 0
[08/14/2024-14:08:46] [I] Inputs:
[08/14/2024-14:08:46] [I] === Reporting Options ===
[08/14/2024-14:08:46] [I] Verbose: Disabled
[08/14/2024-14:08:46] [I] Averages: 10 inferences
[08/14/2024-14:08:46] [I] Percentiles: 90,95,99
[08/14/2024-14:08:46] [I] Dump refittable layers:Disabled
[08/14/2024-14:08:46] [I] Dump output: Disabled
[08/14/2024-14:08:46] [I] Profile: Disabled
[08/14/2024-14:08:46] [I] Export timing to JSON file:
[08/14/2024-14:08:46] [I] Export output to JSON file:
[08/14/2024-14:08:46] [I] Export profile to JSON file:
[08/14/2024-14:08:46] [I]
[08/14/2024-14:08:46] [I] === Device Information ===
[08/14/2024-14:08:46] [I] Selected Device: Orin
[08/14/2024-14:08:46] [I] Compute Capability: 8.7
[08/14/2024-14:08:46] [I] SMs: 16
[08/14/2024-14:08:46] [I] Device Global Memory: 62841 MiB
[08/14/2024-14:08:46] [I] Shared Memory per SM: 164 KiB
[08/14/2024-14:08:46] [I] Memory Bus Width: 256 bits (ECC disabled)
[08/14/2024-14:08:46] [I] Application Compute Clock Rate: 1.3 GHz
[08/14/2024-14:08:46] [I] Application Memory Clock Rate: 1.3 GHz
[08/14/2024-14:08:46] [I]
[08/14/2024-14:08:46] [I] Note: The application clock rates do not reflect the actual clock rates that the GPU is currently running at.
[08/14/2024-14:08:46] [I]
[08/14/2024-14:08:46] [I] TensorRT version: 8.6.2
[08/14/2024-14:08:46] [I] Loading standard plugins
[08/14/2024-14:08:46] [I] [TRT] [MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 33, GPU 7736 (MiB)
[08/14/2024-14:08:51] [I] [TRT] [MemUsageChange] Init builder kernel library: CPU +1154, GPU +1105, now: CPU 1223, GPU 8880 (MiB)
[08/14/2024-14:08:51] [I] Start parsing network model.
[08/14/2024-14:08:51] [I] [TRT] ----------------------------------------------------------------
[08/14/2024-14:08:51] [I] [TRT] Input filename: torch_matmul.onnx
[08/14/2024-14:08:51] [I] [TRT] ONNX IR version: 0.0.5
[08/14/2024-14:08:51] [I] [TRT] Opset version: 10
[08/14/2024-14:08:51] [I] [TRT] Producer name: pytorch
[08/14/2024-14:08:51] [I] [TRT] Producer version: 2.1.0
[08/14/2024-14:08:51] [I] [TRT] Domain:
[08/14/2024-14:08:51] [I] [TRT] Model version: 0
[08/14/2024-14:08:51] [I] [TRT] Doc string:
[08/14/2024-14:08:51] [I] [TRT] ----------------------------------------------------------------
[08/14/2024-14:08:51] [I] Finished parsing network model. Parse time: 0.00176879
[08/14/2024-14:08:51] [W] [TRT] /MatMul: MatMul is unsupported on DLA without kGPU_FALLBACK flag, relax constraints or use Convolution layer instead if fail.
[08/14/2024-14:08:51] [W] [TRT] /MatMul: MatMul is unsupported on DLA, relax constraints or use Convolution layer instead.
[08/14/2024-14:08:51] [E] Error[2]: [network.cpp::operator()::2789] Error Code 2: Internal Error (Assertion allowGPUFallback failed. Layer ‘/MatMul’ is not supported on DLA but GPU fallback is not enabled.)
[08/14/2024-14:08:51] [E] Error[4]: [network.cpp::validate::2901] Error Code 4: Internal Error (DLA validation failed)
[08/14/2024-14:08:51] [E] Engine could not be created from network
[08/14/2024-14:08:51] [E] Building engine failed
[08/14/2024-14:08:51] [E] Failed to create engine from model or file.
[08/14/2024-14:08:51] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec [TensorRT v8602] # /usr/src/tensorrt/bin/trtexec --onnx=torch_matmul.onnx --useDLACore=0 --fp16

I also tried rewriting the matrix multiplication operator directly with the TensorRT API, and I encountered the same error.
Could you please give me some other suggestions? Thanks.

By the way, the error message suggests enabling --allowGPUFallback, but in that case all of the matrix multiplications fall back to the GPU instead of running on the DLA, which defeats my purpose of offloading matrix multiplication to the DLA.

Are there any solutions?

Hi,

Please see our DLA support matrix below:

If the second matrix is a constant, the layer can be rewritten as a 1x1 convolution, which can run on DLA directly.
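As a rough sketch of the equivalence (illustrative only; the shapes follow the earlier example, with B treated as a constant, A's K dimension mapped to the channel dimension, and each column of B becoming one 1x1 filter):

import torch
import torch.nn.functional as F

M, K, N = 128, 1024, 128
A = torch.randn(M, K)              # runtime input
B = torch.randn(K, N)              # constant weight matrix

C_ref = torch.matmul(A, B)         # reference result, shape (M, N)

# The same computation as a 1x1 convolution:
#   input  shape (1, K, M, 1)  -> K is the channel dimension
#   weight shape (N, K, 1, 1)  -> one output channel per column of B
x = A.t().reshape(1, K, M, 1)
w = B.t().reshape(N, K, 1, 1)
y = F.conv2d(x, w)                 # shape (1, N, M, 1)
C_conv = y[0, :, :, 0].t()         # back to (M, N)

print(torch.allclose(C_ref, C_conv, atol=1e-3))   # should print True

With B folded into the convolution weights, the graph contains only a Conv layer, which DLA supports.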

Thanks.

Thanks. So we can only implement matrix multiplication on DLA by using convolution layers or fully connected layers?

By the way, can programs running on the DLA be profiled in detail, like CUDA kernels? In Nsight Systems I can only observe whether the DLA is working; I can't see its utilization and throughput, which is not helpful for optimizing algorithms.

I would appreciate any advice!

Hi,

You can follow the script below to wrap the matmul into a DLA-runnable model.
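A minimal sketch of the idea (illustrative only; the constant matrix B is folded into a 1x1 convolution so the exported ONNX graph contains only a Conv node, and the runtime input A is fed in NCHW layout):

import torch

class MatMulAsConv(torch.nn.Module):
    """y = A @ B with a constant B, expressed as a 1x1 convolution."""
    def __init__(self, B):                               # B: constant (K, N) matrix
        super().__init__()
        K, N = B.shape
        self.conv = torch.nn.Conv2d(K, N, kernel_size=1, bias=False)
        with torch.no_grad():
            self.conv.weight.copy_(B.t().reshape(N, K, 1, 1))

    def forward(self, x):                                # x: (1, K, M, 1), x[0, k, m, 0] == A[m, k]
        return self.conv(x)                              # (1, N, M, 1), out[0, n, m, 0] == (A @ B)[m, n]

B = torch.randn(1024, 128)
model = MatMulAsConv(B).eval()
x = torch.randn(1, 1024, 128, 1)                         # A (128x1024) laid out as NCHW
torch.onnx.export(model, (x,), "matmul_conv.onnx",
                  input_names=['matrixA'], output_names=['output'])

The exported matmul_conv.onnx can then be built with trtexec and --useDLACore as in the earlier command.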

On Nsight Systems, please enable the NvMedia trace to gather DLA info:
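From the command line this looks roughly like the following (the accelerator/NvMedia trace flag name is an assumption and differs between Nsight Systems releases; please confirm with nsys profile --help on your JetPack version):

nsys profile --trace=cuda,nvtx --accelerator-trace=nvmedia --output=matmul_dla /usr/src/tensorrt/bin/trtexec --onnx=matmul_conv.onnx --useDLACore=0 --fp16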

Thanks.

Thanks for the reply.

Following your instructions, I still encounter two problems. I hope to get your help, thank you!

  1. I can now run matrix multiplication on DLA, but when I benchmark a simple case (a 3x3 by 3x1 matrix multiplication repeated 100,000 times), the DLA is nearly 15 times slower than the GPU at both INT8 and FP16 precision. This is hard to believe, because NVIDIA claims that the INT8 compute capability of the Jetson AGX Orin DLA is more than half that of the GPU, so in theory the DLA should be at most about twice as slow as the GPU at INT8, yet I measured a 15x gap. (When I ran the same experiment with a MobileNet model, the difference was in line with expectations: the DLA was only about half as fast as the GPU.)


    (Above: evaluation of matrix multiplication at FP16 and INT8 precision on the DLA and the GPU, respectively.)

    (Above: evaluation of MobileNet inference at FP16 and INT8 precision on the DLA and the GPU, respectively.)

  2. I tried the NvMedia trace you mentioned, but the data about the DLA was just some timestamps, which is not very meaningful. What I want is the utilization and memory throughput of the DLA, not just timestamps. (Maybe I did something wrong; from Nsight Systems I can only get the timestamps of DLA execution. Please advise.)


There has been no update from you for a while, so we assume this is no longer an issue.
Hence, we are closing this topic. If you need further support, please open a new one.
Thanks

Hi,

Could you share the test code with us?
We want to take a look to see if there is any format conversion that can be avoided.

Thanks.