Description
I have a simple network that takes two input batches of images and concatenates them on the channel axis, i.e.
input_1 = batch_size x 64 x 64 x 3
input_2 = batch_size x 64 x 64 x 3
output = batch_size x 64 x 64 x 6
The batch size is a dynamic parameter and varies between 1 and 49.
I create the engine in Python and serialize it, then load it in C++ and run inference. The C++ inference gives me wrong results, though: the first image in each batch is correct, but all the rest are wrong. This is not the case when I use the Python API.
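For reference, the network should behave like a plain channel-wise concatenation. A minimal NumPy sketch of the expected result (shapes and names here are only illustrative):

import numpy as np

batch_size = 7
input_1 = np.random.rand(batch_size, 64, 64, 3).astype(np.float32)
input_2 = np.random.rand(batch_size, 64, 64, 3).astype(np.float32)

# For every image in the batch, the two 3-channel inputs are simply stacked into 6 channels
expected = np.concatenate([input_1, input_2], axis=3)  # shape (batch_size, 64, 64, 6)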
Environment
TensorRT Version: 7.0.0-1+cuda10.0
GPU Type: GeForce GTX 1070
Nvidia Driver Version: 440.100
CUDA Version: 10.0
CUDNN Version: 7.6.5.32-1+cuda10.0
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 2 and 3
TensorFlow Version (if applicable):
PyTorch Version (if applicable):
Baremetal or Container (if container which image + tag):
Relevant Files
The full reproduction (scripts below) is attached as github_repro.zip at the bottom of this post.
Steps To Reproduce
I have included a small reproducible example.
Python script to generate the sample engine.
"""
build_engine.py
"""
import tensorrt as trt
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
def main():
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network:
input_1_tensor = network.add_input(name="input_1_tensor", dtype=trt.float32, shape=(-1, 64, 64, 3))
input_2_tensor = network.add_input(name="input_2_tensor", dtype=trt.float32, shape=(-1, 64, 64, 3))
concat_layer = network.add_concatenation([input_1_tensor, input_2_tensor])
concat_layer.axis = 3
network.mark_output(concat_layer.get_output(0))
config = builder.create_builder_config()
config.max_workspace_size = 1 << 20
optim_profile = builder.create_optimization_profile()
optim_profile.set_shape("input_1_tensor", (1, 64, 64, 3), (15, 64, 64, 3), (49, 64, 64, 3))
optim_profile.set_shape("input_2_tensor", (1, 64, 64, 3), (15, 64, 64, 3), (49, 64, 64, 3))
config.add_optimization_profile(optim_profile)
engine = builder.build_engine(network, config)
with open("sample.plan", "wb") as f:
f.write(engine.serialize())
if __name__ == '__main__':
main()
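For reference, a minimal sketch of the Python-side inference that gives me correct results. It uses pycuda and is illustrative rather than my exact script:

import numpy as np
import pycuda.autoinit  # creates a CUDA context
import pycuda.driver as cuda
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

batch_size = 7
runtime = trt.Runtime(TRT_LOGGER)
with open("sample.plan", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

# Bindings 0 and 1 are the two inputs; set the dynamic batch dimension for this run
context.set_binding_shape(0, (batch_size, 64, 64, 3))
context.set_binding_shape(1, (batch_size, 64, 64, 3))

input_1 = np.random.rand(batch_size, 64, 64, 3).astype(np.float32)
input_2 = np.random.rand(batch_size, 64, 64, 3).astype(np.float32)
output = np.empty((batch_size, 64, 64, 6), dtype=np.float32)

d_input_1 = cuda.mem_alloc(input_1.nbytes)
d_input_2 = cuda.mem_alloc(input_2.nbytes)
d_output = cuda.mem_alloc(output.nbytes)

cuda.memcpy_htod(d_input_1, input_1)
cuda.memcpy_htod(d_input_2, input_2)
context.execute_v2([int(d_input_1), int(d_input_2), int(d_output)])
cuda.memcpy_dtoh(output, d_output)

# Every image in the batch matches a plain NumPy concatenation here
assert np.allclose(output, np.concatenate([input_1, input_2], axis=3))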
C++ file to run inference (github_repro.cpp):
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
#include <NvInfer.h>
#include <cuda_runtime.h>

class Logger : public nvinfer1::ILogger
{
    void log(Severity severity, const char *msg) override
    {
        std::cout << msg << std::endl;
    }
} gLogger;

// Read the serialized engine into a string
std::string load_plan(const std::string &plan_file)
{
    std::ifstream planFile(plan_file, std::ios::binary);
    std::stringstream planBuffer;
    planBuffer << planFile.rdbuf();
    return planBuffer.str();
}

int main()
{
    // Batch dimensions
    int batch_size = 7;
    int height = 64;
    int width = 64;
    int channels = 3;

    // Input info
    std::string plan_path = "sample.plan";
    std::string input_tensor_1 = "input_1_tensor";
    std::string input_tensor_2 = "input_2_tensor";

    // Deserialize the engine and prepare the execution context
    std::string loaded_plan = load_plan(plan_path);
    nvinfer1::IRuntime *runtime = nvinfer1::createInferRuntime(gLogger);
    nvinfer1::ICudaEngine *engine = runtime->deserializeCudaEngine(
        loaded_plan.data(), loaded_plan.size(), nullptr);
    nvinfer1::IExecutionContext *context = engine->createExecutionContext();
    context->setOptimizationProfile(0);

    // Get binding indices for the two inputs
    int input1BindingIndex = engine->getBindingIndex(input_tensor_1.c_str());
    int input2BindingIndex = engine->getBindingIndex(input_tensor_2.c_str());

    // Input and output dims
    nvinfer1::Dims4 input_1_dims = nvinfer1::Dims4(batch_size, height, width, channels);
    nvinfer1::Dims4 input_2_dims = nvinfer1::Dims4(batch_size, height, width, channels);
    nvinfer1::Dims4 output_dims = nvinfer1::Dims4(batch_size, height, width, channels * 2);
    size_t num_inputs = input_1_dims.d[0] * input_1_dims.d[1] * input_1_dims.d[2] * input_1_dims.d[3];
    size_t num_outputs = output_dims.d[0] * output_dims.d[1] * output_dims.d[2] * output_dims.d[3];

    // Host buffers
    std::vector<float> input_1_buffer_host(num_inputs);
    std::vector<float> input_2_buffer_host(num_inputs);
    std::vector<float> output_host(num_outputs);

    // Generate a random batch of images directly into the host buffers
    int data_pointer = 0;
    int img_vol = height * width * channels;
    for (int i = 0; i < batch_size; i++) {
        cv::Mat random_img_1(height, width, CV_32FC3, &input_1_buffer_host[data_pointer]);
        cv::Mat random_img_2(height, width, CV_32FC3, &input_2_buffer_host[data_pointer]);
        cv::randu(random_img_1, cv::Scalar(0, 0, 0), cv::Scalar(1, 1, 1));
        cv::randu(random_img_2, cv::Scalar(0, 0, 0), cv::Scalar(1, 1, 1));
        data_pointer += img_vol;
    }

    // Device buffers
    float *input_1_device, *input_2_device, *output_device;
    size_t num_input_bytes = num_inputs * sizeof(float);
    size_t num_output_bytes = num_outputs * sizeof(float);
    cudaMalloc(&input_1_device, num_input_bytes);
    cudaMalloc(&input_2_device, num_input_bytes);
    cudaMalloc(&output_device, num_output_bytes);

    // Set the dynamic batch dimension for this run
    context->setBindingDimensions(input1BindingIndex, input_1_dims);
    context->setBindingDimensions(input2BindingIndex, input_2_dims);

    // Copy inputs to the device, run inference, copy the output back
    cudaMemcpy(input_1_device, input_1_buffer_host.data(), num_inputs, cudaMemcpyHostToDevice);
    cudaMemcpy(input_2_device, input_2_buffer_host.data(), num_inputs, cudaMemcpyHostToDevice);
    std::vector<void*> bindings = {input_1_device, input_2_device, output_device};
    context->executeV2(bindings.data());
    cudaMemcpy(output_host.data(), output_device, num_outputs, cudaMemcpyDeviceToHost);

    for (size_t i = 0; i < output_host.size(); i++) {
        std::cout << output_host.at(i) << " ";
    }
    std::cout << std::endl;
}
CMakeLists.txt
cmake_minimum_required(VERSION 3.1)
project(SimpleReproduction)

find_package(CUDA REQUIRED)
find_package(OpenCV REQUIRED)

include_directories(
    include
    ${CUDA_INCLUDE_DIRS}
    ${CMAKE_SOURCE_DIR}
    ${OpenCV_INCLUDE_DIRS}
)

find_library(
    NVINFER7
    NAMES libnvinfer.so.7
)

# Set CUDA_NVCC_FLAGS as you would CXX/C flags
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -DMY_DEF=1")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMY_DEF=1 -std=c++11")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMY_DEF=1")
set(CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE OFF)

cuda_add_executable(REPRO_SCRIPT github_repro.cpp)
target_link_libraries(REPRO_SCRIPT ${OpenCV_LIBS} ${NVINFER7})
How to reproduce:
python build_engine.py
mkdir build
cd build
cp ../sample.plan .
cmake ..
make
./REPRO_SCRIPT
I have also attached the scripts in a zip file.
You will notice that there are a lot of zeros in the output tensor, which seems wrong.
github_repro.zip (2.8 KB)