Run part of a DNN on the DLA and part on the GPU

Hi all,

This question is regarding TensorRT on DLA. Is there a way to execute parts of a model concurrently on the DLA and the GPU? That is, I want to exploit the inter-op parallelism in DNNs by mapping each op to the GPU, DLA0, or DLA1. Is this possible? Would it be possible to provide sample code for this? Any help/pointers are greatly appreciated.
For example, in the sample DNN below, I would like to execute the three convs concurrently on the GPU, DLA0, and DLA1.

Hi,

You can try building the network manually with the TensorRT API.

However, since the GPU and DLA store intermediate tensors in different places, this might cause some extra memory-transfer overhead.

Thanks.

Thanks.
I was going over the example here, and it seems to show a simple example for the C++ API. I was also able to find the Python API here (IBuilderConfig — NVIDIA TensorRT Standard Python API Documentation 8.4.3) describing the same. So I guess both the Python and C++ TensorRT APIs support layer-wise mapping.

It would be great if I could get some simple starter code in Python.

Hi,

Yes, both C++ and Python are supported.
You can find the Python example in section 6.4.2 (Python).
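
In case it helps before the full sample, here is a rough starter sketch in Python. This is my own minimal example, not taken from the documentation; the dummy input shape and random weights are placeholders, and it assumes a TensorRT 8.x install on a device with DLA hardware.

import numpy as np
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(EXPLICIT_BATCH)
config = builder.create_builder_config()

# DLA only supports FP16/INT8; GPU_FALLBACK lets unsupported layers run on the GPU.
config.set_flag(trt.BuilderFlag.FP16)
config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
config.default_device_type = trt.DeviceType.DLA   # set on the config instance
config.DLA_core = 0

# A tiny example network built with the network-definition API.
x = network.add_input("x", trt.float32, (1, 16, 32, 32))
w = np.random.randn(16, 16, 3, 3).astype(np.float32)
conv = network.add_convolution_nd(x, 16, (3, 3), w)
network.mark_output(conv.get_output(0))

# Per-layer placement: keep this convolution on the GPU even though the
# default device type is DLA.
if config.can_run_on_DLA(conv):
    config.set_device_type(conv, trt.DeviceType.GPU)

plan = builder.build_serialized_network(network, config)
engine = trt.Runtime(TRT_LOGGER).deserialize_cuda_engine(plan)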

Thanks.

Hi!

Thanks for your reply. I was able to create the above-mentioned DNN in TensorRT using the Python API, but I run into an error.

import os
import sys
import torch
import torch.nn as nn
import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt


# You can set the logger severity higher to suppress messages (or lower to display more messages).
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)


class Block(nn.Module):
    def __init__(self):
        super(Block, self).__init__()
        self.conv1 = nn.Conv2d(384, 384, kernel_size=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(384, 384, kernel_size=1, padding=0, bias=False)
        self.conv3 = nn.Conv2d(384, 384, kernel_size=1, padding=0, bias=False)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        out1 = self.conv1(x)
        out2 = self.conv2(out1)
        out3 = self.conv3(x)
        out4 = self.conv4(x)
        out = torch.cat((out2, out3, out4), 1)
        return out



class ModelData(object):
    INPUT_NAME = "x"
    INPUT_SHAPE = (1, 384, 56, 56)
    OUTPUT_NAME = "out"
    OUTPUT_SIZE = (1, 1152, 56, 56)
    DTYPE = trt.float16


def populate_network(config, network, weights):
    # Configure the network layers based on the weights provided.
    input_tensor = network.add_input(name=ModelData.INPUT_NAME, dtype=ModelData.DTYPE, shape=ModelData.INPUT_SHAPE)

    # padding = network.add_padding_nd(input=input_tensor,pre_padding=(1,1),post_padding=(0,0))

    conv1_w = weights["conv1.weight"].numpy()
    conv1 = network.add_convolution_nd(input=input_tensor, num_output_maps=384, kernel_shape=(1, 1), kernel=conv1_w)
    # print(conv1)
    if config.can_run_on_DLA(conv1):
        print("conv1 can run on DLA")
        config.set_device_type(conv1, trt.DeviceType.GPU)

    conv2_w = weights["conv2.weight"].numpy()
    conv2 = network.add_convolution_nd(input=conv1.get_output(0), num_output_maps=384, kernel_shape=(1, 1), kernel=conv2_w)
    # print(conv2)
    if config.can_run_on_DLA(conv2):
        print("conv2 can run on DLA")
        config.set_device_type(conv2, trt.DeviceType.GPU)



    conv3_w = weights["conv3.weight"].numpy()
    conv3 = network.add_convolution_nd(input=input_tensor, num_output_maps=384, kernel_shape=(1, 1), kernel=conv3_w)
    # print(conv3)
    if config.can_run_on_DLA(conv3):
        print("conv3 can run on DLA")
        config.set_device_type(conv3, trt.DeviceType.DLA)

    conv4_w = weights["conv4.weight"].numpy()
    conv4 = network.add_convolution_nd(input=input_tensor, num_output_maps=384, kernel_shape=(1, 1), kernel=conv4_w)
    if config.can_run_on_DLA(conv4):
        print("conv4 can run on DLA")
        config.set_device_type(conv4, trt.DeviceType.DLA)
    # print(conv4)

    concat = network.add_concatenation([conv2.get_output(0), conv3.get_output(0), conv4.get_output(0)])
    # print(concat)
    concat.get_output(0).name = ModelData.OUTPUT_NAME
    network.mark_output(tensor=concat.get_output(0))


def build_engine(weights):
    # For more information on TRT basics, refer to the introductory samples.
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(EXPLICIT_BATCH)
    config = builder.create_builder_config()
    # config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
    config.set_flag(trt.BuilderFlag.FP16)
    trt.BuilderFlag.GPU_FALLBACK = False
    trt.IBuilderConfig.default_device_type = trt.DeviceType.DLA
    trt.IBuilderConfig.DLA_core = 0
    # config.DLA_CORE = 0
    runtime = trt.Runtime(TRT_LOGGER)

    # config.max_workspace_size = GiB(4)
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, GiB(4))
    # Populate the network using weights from the PyTorch model.
    populate_network(config, network, weights)
    # Build and return an engine.
    plan = builder.build_serialized_network(network, config)
    return runtime.deserialize_cuda_engine(plan)


def main():
    net = Block()
    net.half()
    net_state = net.state_dict()
    
    # print(type(net_state))
    # print(net_state.keys())
    # Do inference with TensorRT.
    engine = build_engine(net_state)

    # Build an engine, allocate buffers and create a stream.
    # For more information on buffer allocation, refer to the introductory samples.
    inputs, outputs, bindings, stream = allocate_buffers(engine)
    context = engine.create_execution_context()

    # case_num = load_random_test_case(mnist_model, pagelocked_buffer=inputs[0].host)
    # For more information on performing inference, refer to the introductory samples.
    # The common.do_inference function will return a list of outputs - we only have one in this case.
    [output] = do_inference_v2(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)

def GiB(val):
    return val * 1 << 30

# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
def allocate_buffers(engine):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()
    for binding in engine:
        size = trt.volume(engine.get_binding_shape(binding)) * EXPLICIT_BATCH
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Allocate host and device buffers
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to device bindings.
        bindings.append(int(device_mem))
        # Append to the appropriate list.
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
    return inputs, outputs, bindings, stream

class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()

def do_inference_v2(context, bindings, inputs, outputs, stream):
    # Transfer input data to the GPU.
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
    # Run inference.
    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU.
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return only the host outputs.
    return [out.host for out in outputs]

if __name__ == "__main__":
    main()

The execution error:

conv1 can run on DLA
conv2 can run on DLA
conv3 can run on DLA
conv4 can run on DLA
[10/17/2022-15:25:26] [TRT] [E] 2: [nvdlaRunner.cpp::execute::49] Error Code 2: Internal Error (Assertion context.dlaContext != nullptr failed. )
[10/17/2022-15:25:26] [TRT] [E] 2: [builder.cpp::buildSerializedNetwork::636] Error Code 2: Internal Error (Assertion engine != nullptr failed. )
Traceback (most recent call last):
  File "example.py", line 174, in <module>
    main()
  File "example.py", line 114, in main
    engine = build_engine(net_state)
  File "example.py", line 103, in build_engine
    return runtime.deserialize_cuda_engine(plan)
TypeError: deserialize_cuda_engine(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorrt.tensorrt.Runtime, serialized_engine: buffer) -> tensorrt.tensorrt.ICudaEngine

Invoked with: <tensorrt.tensorrt.Runtime object at 0xfffeec2d66b0>, None

I have two more questions:

  1. How do I create and visualize the execution trace of a DNN execution within TensorRT itself, or should I use Nsight for this as well?

  2. Looking over the APIs, I only found an option to map each layer to the GPU or DLA, but I couldn't find a way to specify which DLA core to map a layer to.

if config.can_run_on_DLA(conv4):
    config.set_device_type(conv4, trt.DeviceType.DLA)

So could you confirm whether it is possible to map two different layers in a DNN to two different DLA cores?

Thanks for your time!

Hi,

1. There are several helper tools in the TensorRT repository.
Please check whether one of them meets your requirement. For per-layer timings from inside TensorRT itself, there is also the profiler interface sketched below.
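
One option is the trt.IProfiler interface of the Python API. Below is a rough, untested sketch; the LayerTimer class and the commented usage lines are my own illustration built around the script you posted. For a system-level timeline, Nsight Systems is the usual tool, and trtexec also has a --dumpProfile option.

import tensorrt as trt

class LayerTimer(trt.IProfiler):
    """Accumulate the per-layer times that TensorRT reports after each run."""
    def __init__(self):
        trt.IProfiler.__init__(self)
        self.times = {}

    def report_layer_time(self, layer_name, ms):
        self.times[layer_name] = self.times.get(layer_name, 0.0) + ms

# Usage with the engine/context from your script above:
# context = engine.create_execution_context()
# context.profiler = LayerTimer()
# context.execute_v2(bindings)          # synchronous run so all times are reported
# for name, ms in context.profiler.times.items():
#     print(f"{name}: {ms:.3f} ms")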

2. Unfortunately, no.
The DLA core index is set at runtime.
However, that setting applies to the whole inference run.

If you really want to do this, you could try splitting the model into several sub-models and passing the intermediate tensors between them manually; a rough sketch of this idea is below.
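
As an illustration of that idea (my own untested sketch, not an official sample; the helper name build_branch_engine and the reuse of your conv3/conv4 weights are assumptions), you could build one small engine per branch, pin each engine to a different DLA core at build time, and then run the engines yourself and do the concatenation on the host or in another GPU engine:

import numpy as np
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

def build_branch_engine(conv_weight, dla_core):
    """Build a one-convolution sub-model pinned to the given DLA core."""
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(EXPLICIT_BATCH)
    config = builder.create_builder_config()

    config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
    config.default_device_type = trt.DeviceType.DLA
    config.DLA_core = dla_core            # 0 or 1, fixed per engine at build time

    x = network.add_input("x", trt.float16, (1, 384, 56, 56))
    conv = network.add_convolution_nd(x, 384, (1, 1), conv_weight)
    network.mark_output(conv.get_output(0))

    plan = builder.build_serialized_network(network, config)
    runtime = trt.Runtime(TRT_LOGGER)
    return runtime.deserialize_cuda_engine(plan)

# For example, conv3 on DLA core 0 and conv4 on DLA core 1, with conv1/conv2 kept in
# a separate GPU engine. The concatenation then has to be done by you (on the host or
# in another small engine), which is where the extra memory transfers come in.
# engine_dla0 = build_branch_engine(net_state["conv3.weight"].numpy(), dla_core=0)
# engine_dla1 = build_branch_engine(net_state["conv4.weight"].numpy(), dla_core=1)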

Thanks.



Also check out the DLA GitHub page for samples and resources, or to report issues: Recipes and tools for running deep learning workloads on NVIDIA DLA cores for inference applications.

We have an FAQ page that addresses some common questions that we see developers run into: Deep-Learning-Accelerator-SW/FAQ
