Wrong inference results with dynamic batch size in C++ but not in Python

Description

I have a simple network that takes two input batches of images and concatenates them on the channel axis, i.e.

input_1 = batch_size x 64 x 64 x 3
input_2 = batch_size x 64 x 64 x 3

output = batch_size x 64 x 64 x 6

The batch size is a dynamic parameter and varies between 1 and 49.

I create the engine in Python and serialize it, then load it in C++ and run inference. The inference in C++ gives me wrong results, though: the first image in each batch is correct, but all the rest are wrong. This is not the case if I use the Python API.

Environment

TensorRT Version: 7.0.0-1+cuda10.0
GPU Type: GeForce GTX 1070
Nvidia Driver Version: 440.100
CUDA Version: 10.0
CUDNN Version: 7.6.5.32-1+cuda10.0
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 2 and 3
TensorFlow Version (if applicable):
PyTorch Version (if applicable):
Baremetal or Container (if container which image + tag):

Relevant Files

Please attach or include links to any models, data, files, or scripts necessary to reproduce your issue. (Github repo, Google Drive, Dropbox, etc.)

Steps To Reproduce

I have included a small reproducible example.

Python script to generate the sample engine.

"""
build_engine.py
"""
import tensorrt as trt

# Module-wide TensorRT logger, constructed with the WARNING severity level;
# main() hands it to trt.Builder.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def main():
    """Build the two-input concatenation engine and serialize it to ``sample.plan``.

    The network takes two float32 inputs of shape (-1, 64, 64, 3) and
    concatenates them on axis 3. A single optimization profile allows the
    leading (batch) dimension to range from 1 to 49, with 15 as the optimum.

    Raises:
        RuntimeError: if the builder does not return an engine.
    """
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    input_names = ("input_1_tensor", "input_2_tensor")
    input_shape = (-1, 64, 64, 3)

    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(explicit_batch) as network:
        inputs = [
            network.add_input(name=name, dtype=trt.float32, shape=input_shape)
            for name in input_names
        ]
        concat_layer = network.add_concatenation(inputs)
        concat_layer.axis = 3
        network.mark_output(concat_layer.get_output(0))

        config = builder.create_builder_config()
        config.max_workspace_size = 1 << 20

        # Both inputs share the same min / opt / max shapes.
        optim_profile = builder.create_optimization_profile()
        for name in input_names:
            optim_profile.set_shape(name, (1, 64, 64, 3), (15, 64, 64, 3), (49, 64, 64, 3))
        config.add_optimization_profile(optim_profile)

        engine = builder.build_engine(network, config)
        # A failed build yields no engine; fail with a clear message instead of
        # an AttributeError on `.serialize()` below.
        if engine is None:
            raise RuntimeError("TensorRT failed to build the engine; check the logger output")
        with open("sample.plan", "wb") as f:
            f.write(engine.serialize())


# Build and serialize the engine only when this file is run as a script.
if __name__ == "__main__":
    main()

C++ file to run inference

#include <opencv2/opencv.hpp>
#include <NvInfer.h>
#include <cuda_runtime.h>

class Logger : public nvinfer1::ILogger
{
    void log(Severity severity, const char * msg) override
    {
      std::cout << msg << std::endl;
    }
} gLogger;

/// Reads the whole serialized engine file at `plan_file` into memory.
///
/// The stream is opened in binary mode because a serialized plan is raw
/// bytes; a text-mode stream may translate line endings or stop at an
/// end-of-file marker on some platforms and corrupt the data.
///
/// @param plan_file Path of the plan file to read.
/// @return The file's bytes, or an empty string if the file cannot be opened
///         or is empty.
std::string load_plan(const std::string &plan_file) {
  std::ifstream planFile(plan_file, std::ios::in | std::ios::binary);
  if (!planFile) {
    // Same result as before for a missing file, but made explicit.
    return std::string();
  }
  std::stringstream planBuffer;
  planBuffer << planFile.rdbuf();
  return planBuffer.str();
}

// Reproduction driver: deserializes "sample.plan", fills two host buffers of
// batch_size x height x width x channels floats with random values, runs one
// inference with executeV2, and prints every element of the output buffer.
//
// NOTE(review): none of the CUDA / TensorRT return values are checked, and the
// device buffers, context, engine and runtime are never released before exit.
int main() {
  // Batch dimensions
  int batch_size = 7;
  int height = 64;
  int width = 64;
  int channels = 3;

  // Input info
  std::string plan_path = "sample.plan";
  std::string input_tensor_1 = "input_1_tensor";
  std::string input_tensor_2 = "input_2_tensor";
  std::string loaded_plan = load_plan(plan_path);
  nvinfer1::IRuntime *runtime = nvinfer1::createInferRuntime(gLogger);

  // Prepare execution context
  // NOTE(review): `plan` is a second, unused read of the same file; only
  // `loaded_plan` is passed to deserializeCudaEngine.
  std::string plan = load_plan(plan_path);
  nvinfer1::ICudaEngine *engine = runtime->deserializeCudaEngine(
          loaded_plan.data(), loaded_plan.size(), nullptr);
  nvinfer1::IExecutionContext *context = engine->createExecutionContext();
  context->setOptimizationProfile(0);

  // Get bindings
  int input1BindingIndex = engine->getBindingIndex(input_tensor_1.c_str());
  int input2BindingIndex = engine->getBindingIndex(input_tensor_2.c_str());

  // Input and output dims
  nvinfer1::Dims4 input_1_dims = nvinfer1::Dims4(batch_size, height, width, channels);
  nvinfer1::Dims4 input_2_dims = nvinfer1::Dims4(batch_size, height, width, channels);

  // The output has twice the channel count of one input.
  nvinfer1::Dims4
  output_dims = nvinfer1::Dims4(batch_size, height, width, channels * 2);

  // Element counts (number of floats), NOT byte sizes.
  size_t num_inputs = input_1_dims.d[0] * input_1_dims.d[1] * input_1_dims.d[2] * input_1_dims.d[3];
  size_t num_outputs = output_dims.d[0] * output_dims.d[1] * output_dims.d[2] * output_dims.d[3];

  std::vector<float> input_1_buffer_host;
  std::vector<float> input_2_buffer_host;
  std::vector<float> output_host;

  input_1_buffer_host.resize(num_inputs);
  input_2_buffer_host.resize(num_inputs);
  output_host.resize(num_outputs);

  // Offset (in floats) of the current image inside the host buffers, and the
  // number of floats in one image.
  int data_pointer = 0;
  int img_vol = height * width * channels;

  // Generate random batch of images
  // Each cv::Mat wraps a slice of the host vector (no copy), so cv::randu
  // writes the random values straight into the input buffers.
  for(int i = 0; i < batch_size; i++) {
    cv::Mat random_img_1(height, width, CV_32FC3, &input_1_buffer_host[data_pointer]);
    cv::Mat random_img_2(height, width, CV_32FC3, &input_2_buffer_host[data_pointer]);
    cv::randu(random_img_1, cv::Scalar(0, 0, 0), cv::Scalar(1, 1, 1));
    cv::randu(random_img_2, cv::Scalar(0, 0, 0), cv::Scalar(1, 1, 1));
    data_pointer += img_vol;
  }


  // Device buffers are allocated with the correct byte sizes.
  float *input_1_device, *input_2_device, *output_device;
  size_t num_input_bytes = num_inputs * sizeof(float);
  size_t num_output_bytes = num_outputs * sizeof(float);
  cudaMalloc(&input_1_device, num_input_bytes);
  cudaMalloc(&input_2_device, num_input_bytes);
  cudaMalloc(&output_device, num_output_bytes);

  // NOTE(review): staticDims is never used.
  std::vector<int> staticDims = {batch_size, height, width, channels};

  // Resolve the dynamic batch dimension for both inputs before executing.
  context->setBindingDimensions(input1BindingIndex, input_1_dims);
  context->setBindingDimensions(input2BindingIndex, input_2_dims);

  // NOTE(review): BUG - cudaMemcpy takes its size in BYTES, but num_inputs is
  // an element count, so only 1/sizeof(float) of each input buffer is
  // uploaded. The byte size is num_input_bytes (computed above).
  cudaMemcpy(input_1_device, input_1_buffer_host.data(), num_inputs, cudaMemcpyHostToDevice);
  cudaMemcpy(input_2_device, input_2_buffer_host.data(), num_inputs, cudaMemcpyHostToDevice);

  // NOTE(review): the pointers are placed by position, not by
  // input1BindingIndex / input2BindingIndex - this assumes the engine's
  // bindings are ordered input 1, input 2, output; confirm.
  std::vector<void*> bindings = {input_1_device, input_2_device, output_device};
  context->executeV2(bindings.data());

  // NOTE(review): BUG - same element-count-vs-bytes mistake as above; the
  // byte size is num_output_bytes, so most of output_host keeps the zeros it
  // was resized with.
  cudaMemcpy(output_host.data(), output_device, num_outputs, cudaMemcpyDeviceToHost);

  // Dump the whole output tensor, space separated.
  for(int i = 0; i < output_host.size(); i++){
    std::cout << output_host.at(i) << " ";
  }

}

CMakeLists.txt

# cmake_minimum_required() must come before project() so that the version and
# policy settings are in effect when the project is configured.
cmake_minimum_required(VERSION 3.1)
project(SimpleReproduction)
find_package(CUDA REQUIRED)
find_package(OpenCV REQUIRED)
include_directories(
        include
        # FindCUDA exports its header paths as CUDA_INCLUDE_DIRS.
        ${CUDA_INCLUDE_DIRS}
        ${CMAKE_SOURCE_DIR}
        ${OpenCV_INCLUDE_DIRS}
)

# Locate the TensorRT 7 runtime library.
find_library(
        NVINFER7
        NAMES libnvinfer.so.7
)

# set CUDA_NVCC_FLAGS as you would do with CXX/C FLAGS
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -DMY_DEF=1")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMY_DEF=1 -std=c++11" )
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMY_DEF=1" )
set(CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE OFF)

# Build the reproduction binary and link it against OpenCV and TensorRT.
cuda_add_executable(REPRO_SCRIPT github_repro.cpp)
target_link_libraries(REPRO_SCRIPT ${OpenCV_LIBS} ${NVINFER7})

How to reproduce:

python build_engine.py
mkdir build
cd build
cp ../sample.plan .
cmake ..
make
./REPRO_SCRIPT

I have also attached the scripts in a zip file.

You will notice that there are a lot of zeros in the output tensor, which seems to be wrong.

github_repro.zip (2.8 KB)

Hi @copah

This should not be the case.
I tried working with your model, and it works fine with trtexec.
Since you mentioned that you are not facing any issue with Python either, the problem is most likely in your C++ script.
I am checking that and will get back to you.
Thanks!

1 Like

Hi @copah,

The output copy should pass the size in bytes, not the number of elements:

cudaMemcpy(output_host.data(), output_device, num_output_bytes, cudaMemcpyDeviceToHost);

The same applies to the inputs.

They should be

cudaMemcpy(input_1_device, input_1_buffer_host.data(), num_input_bytes, cudaMemcpyHostToDevice);
cudaMemcpy(input_2_device, input_2_buffer_host.data(), num_input_bytes, cudaMemcpyHostToDevice);

Hopefully this resolves the issue.
Thanks!

Ahh, that was a stupid error. Thank you very much for looking into it!