TensorRT inference takes much longer than expected

I am using TensorRT to deploy a model on Windows 10, but I find the inference time is much longer than expected. Here is how I convert the model and run the inference.

First, I train my model with PyTorch, export it to ONNX, and then use onnxsim to simplify it (with a dynamic batch axis):

    inputs = ["input"]
    outputs = ["output"]
    dynamic_axes = {"input": {0: 'batch'}, "output": {0: 'batch'}}
    torch.onnx.export(pth_model,
                        x,
                        export_onnx_file,
                        opset_version=12,
                        # do_constant_folding=False,
                        input_names=inputs,		
                        output_names=outputs,	
                        verbose=False,
                        dynamic_axes=dynamic_axes,
                        # example_outputs=torch.randn(1, 3, 192, 2048),
                        operator_export_type=torch.onnx.OperatorExportTypes.ONNX
                        )

    onnx_model = onnx.load(export_onnx_file)
    onnx.checker.check_model(onnx_model, full_check=True)
    
    if args.print:
        print(onnx.helper.printable_graph(onnx_model.graph))

    # onnx_sim_model, flag = onnxsim.simplify(onnx_model)
    onnx_sim_model, flag = onnxsim.simplify(onnx_model, input_shapes={"input":(batch_size, 1, 192, 2048)})
    onnx.save_model(onnx_sim_model, export_onnxsim_file)

Then, I convert the simplified ONNX model to a TensorRT plan, following this tutorial: https://developer.nvidia.com/blog/speeding-up-deep-learning-inference-using-tensorrt/

    nvinfer1::ICudaEngine* SpiderONNXModelLoader::CreateCudaEngine(string const& onnxModelPath, int batchSize)
    {
        unique_ptr<nvinfer1::IBuilder, Destroy<nvinfer1::IBuilder>> builder{ nvinfer1::createInferBuilder(g_logger_) };
        const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
        unique_ptr<nvinfer1::INetworkDefinition, Destroy<nvinfer1::INetworkDefinition>> network{ builder->createNetworkV2(explicitBatch) };
        unique_ptr<nvonnxparser::IParser, Destroy<nvonnxparser::IParser>> parser{ nvonnxparser::createParser(*network, g_logger_) };
        unique_ptr<nvinfer1::IBuilderConfig, Destroy<nvinfer1::IBuilderConfig>> config{ builder->createBuilderConfig() };

        if (!parser->parseFromFile(onnxModelPath.c_str(), static_cast<int>(ILogger::Severity::kINFO)))
        {
            qWarning() << "ERROR: could not parse input engine.";
            return nullptr;
        }
        builder->setMaxBatchSize(batchSize);
        // Note: (1 << 60) overflows a 32-bit int; use an unsigned 64-bit literal instead (1 GiB here).
        config->setMaxWorkspaceSize(1ULL << 30);
        config->setFlag(BuilderFlag::kFP16);
        //builder->setFp16Mode(builder->platformHasFastFp16());

        // Optimization profile: kMIN and kOPT are set for batch 1, kMAX for batchSize.
        auto profile = builder->createOptimizationProfile();
        profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMIN, Dims4{ 1, 1, 192, 2048 });
        profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kOPT, Dims4{ 1, 1, 192, 2048 });
        profile->setDimensions(network->getInput(0)->getName(), OptProfileSelector::kMAX, Dims4{ batchSize, 1, 192, 2048 });
        config->addOptimizationProfile(profile);

        return builder->buildEngineWithConfig(*network, *config);
    }

The comparison of inference time between libtorch and the TensorRT C++ API (same input size: batch 19, channels 1, height 192, width 2048):
libtorch: around 60 ms
TensorRT: around 140 ms

    // code for TensorRT
    qint64 begin3 = QDateTime::currentMSecsSinceEpoch();
    bool flag1 = context->enqueueV2(bindings, stream, nullptr);
    cudaStreamSynchronize(stream);
    qWarning() << QString::fromStdString("Enqueue") + " Time:" << QString::number(
        QDateTime::currentMSecsSinceEpoch() - begin3);

    // code for libtorch
    qint64 begin3 = QDateTime::currentMSecsSinceEpoch();
    tensor_out = pt_model.forward(tensor_inputs).toTensor();
    c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
    AT_CUDA_CHECK(cudaStreamSynchronize(stream));
    qWarning() << QString::fromStdString("Inference") + " Time:" << QString::number(
        QDateTime::currentMSecsSinceEpoch() - begin3);

Environment:
PyTorch/libtorch: 1.5
TensorRT: 7.2.1
CUDA: 10.2
cuDNN: 8.0.4
ONNX: 1.7
GPU: RTX 2060

The model files can be found here: google drive

Thanks a lot

Hi @peter0431,
Can you try profile optimization and see if there are any improvements?
https://docs.nvidia.com/deeplearning/tensorrt/best-practices/index.html

Thanks!
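
One concrete point from the dynamic-shape documentation: TensorRT optimizes kernel selection for the kOPT dimensions of the profile, and in CreateCudaEngine above kOPT is set to batch 1 while the measurement uses batch 19. As a rough sketch (variable names follow the code above; the batch of 19 is only assumed from the measurements), the profile could be built around the batch size actually used at inference:

    // Sketch: make kOPT match the batch size used at inference time (assumed to be 19 here).
    const int deployBatch = 19;
    auto profile = builder->createOptimizationProfile();
    const char* inputName = network->getInput(0)->getName();
    profile->setDimensions(inputName, OptProfileSelector::kMIN, Dims4{ 1, 1, 192, 2048 });
    profile->setDimensions(inputName, OptProfileSelector::kOPT, Dims4{ deployBatch, 1, 192, 2048 });
    profile->setDimensions(inputName, OptProfileSelector::kMAX, Dims4{ deployBatch, 1, 192, 2048 });
    config->addOptimizationProfile(profile);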

Hi @AakankshaS, I have tried different models, and the results are pretty confusing. This time, I used trtexec for conversion and inference. The results are shown below:

The conversion command is (the shapes differ according to the table above):

    trtexec --workspace=2048 --onnx=segmentation_1123.onnx --minShapes=input:1x3x192x2048 --optShapes=input:1x3x192x2048 --maxShapes=input:19x3x192x2048 --saveEngine=segmentation_1123.trt --best --shapes=input:1x3x192x2048
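
For context, this is roughly how the saved plan is then loaded and run from C++. It is only a sketch assuming the dynamic-shape profile from the trtexec command above (g_logger_ and batchSize are taken from the earlier code; buffer allocation is omitted):

    #include <NvInfer.h>
    #include <fstream>
    #include <iterator>
    #include <vector>

    // Deserialize segmentation_1123.trt and pick the concrete input shape before enqueueV2.
    std::ifstream planFile("segmentation_1123.trt", std::ios::binary);
    std::vector<char> plan((std::istreambuf_iterator<char>(planFile)),
                           std::istreambuf_iterator<char>());

    nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(g_logger_);
    nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(plan.data(), plan.size(), nullptr);
    nvinfer1::IExecutionContext* context = engine->createExecutionContext();

    // With a dynamic batch axis, the actual dimensions must be set on the context per inference.
    context->setBindingDimensions(0, nvinfer1::Dims4{ batchSize, 3, 192, 2048 });

    // ... allocate and fill bindings, then:
    // context->enqueueV2(bindings, stream, nullptr);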

Besides, since BiSeNet V2 and Fast-SCNN are more recent and have fewer parameters compared to BiSeNet V1, I don't understand why BiSeNet V1 consumes less time than BiSeNet V2 and Fast-SCNN. Thank you.