Trtexec and dynamic batch size


I am trying to convert a PyTorch model to TensorRT and then do inference in TensorRT using the Python API.

My model takes two inputs, left_input and right_input, and outputs a cost_volume. I want the batch size to be dynamic and accept either a batch size of 1 or 2.

Can I use trtexec to generate an optimized engine for dynamic input shapes?

My current call:

	trtexec \
		--verbose \
		--explicitBatch \
		--minShapes=left_input:1x3x512x512,right_input:1x3x512x512,cost_volume:1x20x128x128 \
		--optShapes=left_input:2x3x512x512,right_input:2x3x512x512,cost_volume:2x20x128x128 \
		--maxShapes=left_input:2x3x512x512,right_input:2x3x512x512,cost_volume:2x20x128x128 \
		--onnx='my_onnx_model' \
		--saveEngine='my_trt_model'

But I cannot get the inference in Python to work with this engine. It works for a batch size of 1, but with a batch size of 2, only the first sample in the batch is correct.

When I load the engine in Python I get:

In [1]: engine.get_binding_name(0)
Out[1]: u'left_input'

In [2]: engine.get_binding_name(1)
Out[2]: u'right_input'

In [3]: engine.get_binding_name(2)
Out[3]: u'cost_volume'

In [4]: engine.get_binding_shape(0)
Out[4]: (-1, 3, 512, 512)

In [5]: engine.get_binding_shape(1)
Out[5]: (-1, 3, 512, 512)

In [6]: engine.get_binding_shape(2)
Out[6]: (1, 20, 128, 128)

In [7]: execution_context = engine.create_execution_context()

In [8]: execution_context.get_binding_shape(0)
Out[8]: (-1, 3, 512, 512)

In [9]: execution_context.get_binding_shape(1)
Out[9]: (-1, 3, 512, 512)

In [10]: execution_context.get_binding_shape(2)
[TensorRT] ERROR: Parameter check failed at: engine.cpp::resolveSlots::1092, condition: allInputDimensionsSpecified(routine)
Out[10]: (0)

I am wondering why the output binding doesn't get a dynamic batch dimension.

Any guidance here would be appreciated.


TensorRT Version: 7.0.0-1+cuda10.0 amd64
GPU Type: GeForce GTX 1070
Nvidia Driver Version: 440.82
CUDA Version: 10.0
CUDNN Version: amd64
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 3.6
PyTorch Version (if applicable): 1.4.0

Hi, can you please share the model and script so that we can help better?
Also, if possible, please share the verbose error log.

You need to call execution_context.set_binding_shape on your input bindings to set them to the concrete batch size you are planning to use.

Only after all input shapes have been specified can you call get_binding_shape on the output bindings (or use the context at all). That is what the allInputDimensionsSpecified error is telling you.

Also note that --minShapes/--optShapes/--maxShapes should only list the network inputs (left_input and right_input); the output shape of cost_volume is derived from them by TensorRT, so it does not need to appear there.
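To make this concrete, here is a minimal sketch of the inference setup with the TensorRT 7 Python API. The engine filename, binding indices, and shapes are taken from the question above and should be treated as placeholders for your own setup; buffer allocation and the actual execute call are only indicated in comments.

```python
# Sketch: dynamic-shape inference setup with the TensorRT 7 Python API.
# Assumes the engine built by the trtexec call in the question, where
# binding 0 = left_input, 1 = right_input, 2 = cost_volume.
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Deserialize the engine saved by trtexec.
with open("my_trt_model", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

context = engine.create_execution_context()

batch_size = 2  # anything within the [minShapes, maxShapes] range

# Specify a concrete shape for every dynamic input binding FIRST.
context.set_binding_shape(0, (batch_size, 3, 512, 512))  # left_input
context.set_binding_shape(1, (batch_size, 3, 512, 512))  # right_input

# Only now are the output dimensions resolved.
assert context.all_binding_shapes_specified
print(context.get_binding_shape(2))  # cost_volume, batch dim now concrete

# Device buffers must then be allocated to match these resolved shapes
# before calling context.execute_v2(bindings) or execute_async_v2(...).
```

Note that sizing your host/device buffers for the resolved per-call shapes (rather than always for batch size 1) is what makes the second sample in a batch of 2 come out correctly.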