TensorRT Python API has a bug in DLA usage

Description

When trying to use the DLA for inference on a Jetson Xavier AGX via the TensorRT Python API, I observe that DLA0 is always the one used for inference, irrespective of whether config.DLA_core is set to 1 or 0.
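
For reference, the DLA-related build settings in my script look roughly like this (a simplified sketch, not the full attached trt_inference.py):

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

builder = trt.Builder(TRT_LOGGER)
config = builder.create_builder_config()

# Request DLA for the whole network and allow GPU fallback for
# layers DLA cannot run; DLA requires FP16 or INT8.
config.default_device_type = trt.DeviceType.DLA
config.DLA_core = 1                      # expected to select DLA1 (or 0), but inference always runs on DLA0
config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
config.set_flag(trt.BuilderFlag.FP16)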

Environment

Jetson Xavier AGX GPU and DLA
TensorRT Version: 8.4.0
CUDA Version: 11.4
CUDNN Version: 8.3.2
Python Version (if applicable): 3.8

Relevant Files

trt_inference.py (4.5 KB)
dla_usage.sh (672 Bytes)

Steps To Reproduce

Run the attached trt_inference.py with an ONNX model and a JPG image as input arguments (e.g. python3 trt_inference.py resnet50.onnx img.jpg), and check the DLA usage via the attached dla_usage.sh with watch -n 1 dla_usage.sh. This shows that it is always DLA0 that is used for inference, while both DLA0 and DLA1 are used for building the engine.
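
In case the attachment is not accessible, here is a rough Python sketch of one way to check which DLA core is active (this is not the attached dla_usage.sh; the sysfs paths are assumptions for Jetson AGX Xavier and may differ between L4T releases, so adjust them as needed):

from pathlib import Path

# Assumed sysfs nodes for the two DLA cores on Jetson AGX Xavier;
# verify/adjust these paths for your L4T release.
DLA_STATUS = {
    "DLA0": "/sys/devices/platform/host1x/15880000.nvdla0/power/runtime_status",
    "DLA1": "/sys/devices/platform/host1x/158c0000.nvdla1/power/runtime_status",
}

for name, node in DLA_STATUS.items():
    path = Path(node)
    state = path.read_text().strip() if path.exists() else "node not found"
    print(f"{name}: {state}")   # typically "active" while in use, "suspended" otherwise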

Workaround

The C++ API does not have this issue; it uses the DLA core specified by the user.

Hi,
Please refer to the links below related to custom plugin implementation and samples:

While the IPluginV2 and IPluginV2Ext interfaces are still supported for backward compatibility with TensorRT 5.1 and 6.0.x respectively, we recommend that you write new plugins or refactor existing ones to target the IPluginV2DynamicExt or IPluginV2IOExt interfaces instead.

Thanks!

Sorry, I did not understand the reply. The issue is not a custom plugin implementation issue, and I am already allowing GPU fallback for layers that cannot run on DLA. The issue is that through the TensorRT Python API I can do inference only on DLA0.

Hi,

Please refer to the following doc regarding DLA, which may be helpful to you.

We are moving this post to the Jetson AGX Xavier forum to get better help.

Thank you.

Hi,

Could you share the corresponding C++ source so we can compare it?
Thanks.

Please find the C++ source file and the corresponding CMakeLists.txt.
The forum does not allow me to upload the .cpp file, so I am copying it directly here.

Usage, once the project is built:
./trt_resnet resnet50.onnx tiger.jpg

As mentioned before, with the C++ API the DLA usage works perfectly fine as per the user's choice (0 or 1), unlike the Python API.

CMakeLists.txt (2.5 KB)

#include <iostream>
#include <fstream>
#include <NvInfer.h>
#include <memory>
#include <NvOnnxParser.h>
#include <vector>
#include <cuda_runtime_api.h>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/core/cuda.hpp>
#include <opencv2/cudawarping.hpp>
#include <opencv2/core.hpp>
#include <opencv2/cudaarithm.hpp>
#include <algorithm>
#include <numeric>

using namespace cv;

// utilities ----------------------------------------------------------------------------------------------------------
// class to log errors, warnings, and other information during the build and inference phases
class Logger : public nvinfer1::ILogger
{
public:
    void log(Severity severity, const char* msg) noexcept override {
        // remove this 'if' if you need more logged info
        if ((severity == Severity::kERROR) || (severity == Severity::kINTERNAL_ERROR)) {
            std::cout << msg << "\n";
        }
    }
} gLogger;

// destroy TensorRT objects if something goes wrong
struct TRTDestroy
{
    template <class T>
    void operator()(T* obj) const
    {
        if (obj)
        {
            obj->destroy();
        }
    }
};

template <class T>
using TRTUniquePtr = std::unique_ptr<T, TRTDestroy>;

// calculate size of tensor
size_t getSizeByDim(const nvinfer1::Dims& dims)
{
    size_t size = 1;
    for (int32_t i = 0; i < dims.nbDims; ++i)
    {
        size *= dims.d[i];
    }
    return size;
}

// get classes names
std::vector<std::string> getClassNames(const std::string& imagenet_classes)
{
    std::ifstream classes_file(imagenet_classes);
    std::vector<std::string> classes;
    if (!classes_file.good())
    {
        std::cerr << "ERROR: can't read file with classes names.\n";
        return classes;
    }
    std::string class_name;
    while (std::getline(classes_file, class_name))
    {
        classes.push_back(class_name);
    }
    return classes;
}

// preprocessing stage ------------------------------------------------------------------------------------------------
void preprocessImage(const std::string& image_path, float* gpu_input, const nvinfer1::Dims& dims)
{
    // read input image
    cv::Mat frame = cv::imread(image_path);
    if (frame.empty())
    {
        std::cerr << "Input image " << image_path << " load failed\n";
        return;
    }    
    cuda::GpuMat gpu_frame;
    // upload image to GPU
    gpu_frame.upload(frame);

    auto input_width = 224;
    auto input_height = 224;
    auto channels = 3;
    std::cout << "Input height : " << input_height << std::endl;
    std::cout << "Input width : " << input_width << std::endl;
    std::cout << "Channels : " << channels << std::endl;
    auto input_size = cv::Size(input_width, input_height);
    // resize
    cuda::GpuMat resized;
    cuda::resize(gpu_frame, resized, input_size, 0, 0, cv::INTER_NEAREST);
    // normalize
    cuda::GpuMat flt_image;
    resized.convertTo(flt_image, CV_32FC3, 1.f / 255.f);
    cuda::subtract(flt_image, cv::Scalar(0.485f, 0.456f, 0.406f), flt_image, cv::noArray(), -1);
    cuda::divide(flt_image, cv::Scalar(0.229f, 0.224f, 0.225f), flt_image, 1, -1);
    // to tensor

    std::vector<cuda::GpuMat> chw;
    for (int i = 0; i < channels; ++i)
    {
        chw.emplace_back(cuda::GpuMat(input_size, CV_32FC1, gpu_input + i * input_width * input_height));
    }
    cuda::split(flt_image, chw);
}

// post-processing stage ----------------------------------------------------------------------------------------------
void postprocessResults(float *gpu_output, const nvinfer1::Dims &dims, int batch_size)
{
    // get class names
    auto classes = getClassNames("../imagenet_classes.txt");

    // copy results from GPU to CPU
    std::vector<float> cpu_output(getSizeByDim(dims) * batch_size);
    cudaMemcpy(cpu_output.data(), gpu_output, cpu_output.size() * sizeof(float), cudaMemcpyDeviceToHost);

    // calculate softmax
    std::transform(cpu_output.begin(), cpu_output.end(), cpu_output.begin(), [](float val) {return std::exp(val);});
    auto sum = std::accumulate(cpu_output.begin(), cpu_output.end(), 0.0);
    // find top classes predicted by the model
    std::vector<int> indices(getSizeByDim(dims) * batch_size);
    std::cout << "Indices size: " << indices.size() << std::endl;
    std::cout << "Classes size: " << classes.size() << std::endl;
    std::cout << "cpu_output size: " << cpu_output.size() << std::endl;
    std::iota(indices.begin(), indices.end(), 0); // generate sequence 0, 1, 2, 3, ..., 999
    std::sort(indices.begin(), indices.end(), [&cpu_output](int i1, int i2) {return cpu_output[i1] > cpu_output[i2];});
    // print results

    int i = 0;
    if (classes.size() > indices[i])
    {
       std::cout << "class: " << classes[indices[i]] << " | " << std::endl;
    }
    std::cout << "confidence: " << 100 * cpu_output[indices[0]] / sum << "% | index: " << indices[0] << std::endl;
    /*while (cpu_output[indices[i]] / sum > 0.005)
    {
        if (classes.size() > indices[i])
        {
            std::cout << "class: " << classes[indices[i]] << " | " << std::endl;
        }
        else
        {
           std::cout <<"index beyond class" << std::endl;
        }
        std::cout << "confidence: " << 100 * cpu_output[indices[i]] / sum << "% | index: " << indices[i] << std::endl;
        ++i;
    }*/
}

// initialize TensorRT engine and parse ONNX model --------------------------------------------------------------------
void parseOnnxModel(const std::string& model_path, TRTUniquePtr<nvinfer1::ICudaEngine>& engine,
                    TRTUniquePtr<nvinfer1::IExecutionContext>& context)
{
    TRTUniquePtr<nvinfer1::IBuilder> builder{nvinfer1::createInferBuilder(gLogger)};
    const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    TRTUniquePtr<nvinfer1::INetworkDefinition> network{builder->createNetworkV2(explicitBatch)};
    TRTUniquePtr<nvonnxparser::IParser> parser{nvonnxparser::createParser(*network, gLogger)};
    TRTUniquePtr<nvinfer1::IBuilderConfig> config{builder->createBuilderConfig()};
    // parse ONNX
    if (!parser->parseFromFile(model_path.c_str(), static_cast<int>(nvinfer1::ILogger::Severity::kINFO)))
    {
        std::cerr << "ERROR: could not parse the model.\n";
        return;
    }
    // allow TensorRT to use up to 1GB of GPU memory for tactic selection.
    config->setMaxWorkspaceSize(1ULL << 30);

    // we have only one image in batch
    builder->setMaxBatchSize(1);

    if (builder->getNbDLACores() == 0)
    {
        std::cerr << "Trying to use DLA core on a platform that doesn't have any DLA cores" << std::endl;
        assert("Error: use DLA core on a platform that doesn't have any DLA cores" && false);
    }
    bool allowGPUFallback = true;
    if (allowGPUFallback)
    {
        config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
    }
    if (!config->getFlag(nvinfer1::BuilderFlag::kINT8))
    {
        // User has not requested INT8 mode.
        // By default run in FP16 mode; FP32 mode is not permitted on DLA.
        config->setFlag(nvinfer1::BuilderFlag::kFP16);
    }
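    // Assign the whole network to DLA and pin the build to DLA core 1.
    // As confirmed with this C++ version, the selection is honored at execution time.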
    config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
    config->setDLACore(1);
    config->setFlag(nvinfer1::BuilderFlag::kSTRICT_TYPES);

    auto profile = builder->createOptimizationProfile();
    profile->setDimensions(network->getInput(0)->getName(), nvinfer1::OptProfileSelector::kMIN, nvinfer1::Dims4{ 1, 3, 224 , 224 });
    profile->setDimensions(network->getInput(0)->getName(), nvinfer1::OptProfileSelector::kOPT, nvinfer1::Dims4{ 1, 3, 224 , 224 });
    profile->setDimensions(network->getInput(0)->getName(), nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims4{ 1, 3, 224 , 224 });
    config->addOptimizationProfile(profile);

    // generate TensorRT engine optimized for the target platform
    engine.reset(builder->buildEngineWithConfig(*network, *config));
    context.reset(engine->createExecutionContext());
}

// main pipeline ------------------------------------------------------------------------------------------------------
int main(int argc, char* argv[])
{
   std::cout << "Start" << std::endl;
    if (argc < 3)
    {
        std::cerr << "usage: " << argv[0] << " model.onnx image.jpg\n";
        return -1;
    }
    std::string model_path(argv[1]);
    std::string image_path(argv[2]);
    int batch_size = 1;

    // initialize TensorRT engine and parse ONNX model
    TRTUniquePtr<nvinfer1::ICudaEngine> engine{nullptr};
    TRTUniquePtr<nvinfer1::IExecutionContext> context{nullptr};
    std::cout << "Parsing onnx model" << std::endl;
    parseOnnxModel(model_path, engine, context);
    std::cout << "Parsing onnx model done" << std::endl;

    std::cout << "Allocating i/p o/p memory" << std::endl;
    // get sizes of input and output and allocate memory required for input data and for output data
    std::vector<nvinfer1::Dims> input_dims; // we expect only one input
    std::vector<nvinfer1::Dims> output_dims; // and one output
    std::vector<void*> buffers(engine->getNbBindings()); // buffers for input and output data

    for (int32_t i = 0; i < engine->getNbBindings(); ++i)
    {
        auto binding_size = getSizeByDim(engine->getBindingDimensions(i)) * batch_size * sizeof(float);
        cudaMalloc(&buffers[i], binding_size);

        if (engine->bindingIsInput(i))
        {
            input_dims.emplace_back(engine->getBindingDimensions(i));
        }
        else
        {
            output_dims.emplace_back(engine->getBindingDimensions(i));
        }
    }
    if (input_dims.empty() || output_dims.empty())
    {
        std::cerr << "Expect at least one input and one output for network\n";
        return -1;
    }
    std::cout << "Allocating i/p o/p memory done" << std::endl;

    // preprocess input data
    std::cout << "Preprocessing image" << std::endl;
    preprocessImage(image_path, (float *) buffers[0], input_dims[0]);
    std::cout << "Preprocessing image done" << std::endl;
    // inference
    std::cout << "Inference start" << std::endl;
    for(int i = 0; i < 1000; ++i)
    {
       context->enqueue(batch_size, buffers.data(), 0, nullptr);
    }
    std::cout << "Inference done" << std::endl;
    // postprocess results
    //std::cout << "Postprocessing results" << std::endl;
    //postprocessResults((float *) buffers[1], output_dims[0], batch_size);

    for (void* buf : buffers)
    {
        cudaFree(buf);
    }

    std::cout << "Freed all memory" << std::endl;
    std::cout << "End!" << std::endl;
    return 0;
}

Hi,

Thanks for the source to reproduce.

Confirmed that we can reproduce this issue internally.
Will share more information with you later.

Hi,

Please try the following update to specify the DLA core at runtime.

diff --git a/trt_inference.py b/trt_inference.py
index 87b40c5..d09cd8c 100644
--- a/trt_inference.py
+++ b/trt_inference.py
@@ -50,6 +50,7 @@ def build_engine(onnx_file_path):
    network = builder.create_network(*EXPLICIT_BATCH)
    parser = trt.OnnxParser(network, TRT_LOGGER)
    runtime = trt.Runtime(TRT_LOGGER)
+   runtime.DLA_core = 1
 
    # parse ONNX
    with open(onnx_file_path, 'rb') as model:
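
For clarity, here is a minimal sketch of where this attribute fits in the deserialization path (the helper below is illustrative, not the exact attached script):

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def deserialize_engine_on_dla(plan_bytes, dla_core=1):
    # The runtime that deserializes the plan decides which DLA core the
    # engine executes on, so it must be told the core explicitly.
    runtime = trt.Runtime(TRT_LOGGER)
    runtime.DLA_core = dla_core
    return runtime.deserialize_cuda_engine(plan_bytes)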

Thanks.


Thanks for the answer provided.
I can confirm that this fixes the issue: the DLA usage with the Python API now follows what the user specifies.
So, does the DLA core have to be provided to the runtime only and not via the config? In the C++ API it seems to work when provided via the config.

Hi,

The APIs used in Python and C++ are different.
The Python sample serializes the engine and then deserializes it without applying the config again.

To align with C++, you can use the function below instead.
We confirmed that DLA1 is active with the following change:

https://github.com/NVIDIA/TensorRT/blob/release/8.4/python/src/infer/pyCore.cpp#L927

diff --git a/trt_inference.py b/trt_inference.py
index 8113cc5..876b765 100644
--- a/trt_inference.py
+++ b/trt_inference.py
@@ -75,8 +75,9 @@ def build_engine(onnx_file_path):

    config1.add_optimization_profile(profile)

-   trt_model_object1 = builder.build_serialized_network(network, config1)
-   engine1 = runtime.deserialize_cuda_engine(trt_model_object1)
+   engine1 = builder.build_engine(network, config1)
+   #trt_model_object1 = builder.build_serialized_network(network, config1)
+   #engine1 = runtime.deserialize_cuda_engine(trt_model_object1)
    #engine1 = build_engine(network, config1)

    print("Completed creating Engine")

Thanks.


The above solution works. Thanks!
