[TensorRT] ERROR: Parameter check failed at: …/builder/Network.cpp::addInput::671, condition: isValidDims(dims, hasImplicitBatchDimension())
#!/bin/bash
# Export the TensorFlow graph at PB_PATH to ONNX with tf2onnx, keeping only
# the listed input and output tensors.
#
# Fail fast: stop on the first failing command, on references to unset
# variables, and on failures anywhere in a pipeline.
set -euo pipefail

INPUTs=input_images:0
OUTPUTs=feature_fusion/Conv_11/Sigmoid:0,feature_fusion/Conv_12/Sigmoid:0
PB_PATH=./models/resnet50_30w_nchw_no_is_training.pb
ONNX_PATH=./models/resnet50_30w_nchw_no_is_training.onnx

# NOTE(review): if TensorRT later rejects the exported input inside addInput
# (condition isValidDims(...)), the ONNX input probably carries dimensions
# TensorRT cannot accept. tf2onnx allows overriding the input shape, e.g.
# INPUTs='input_images:0[N,C,H,W]' -- confirm the real shape before pinning it.

# Every expansion is quoted so paths with spaces and shape overrides containing
# '[' ']' reach tf2onnx as single, unglobbed arguments.
python3 -m tf2onnx.convert \
    --input "$PB_PATH" \
    --output "$ONNX_PATH" \
    --inputs "$INPUTs" \
    --outputs "$OUTPUTs" \
    --fold_const \
    --opset 10 \
    --verbose
import os
import warnings

# Quiet the TensorFlow C++ log output and Python FutureWarnings. Both are set
# before the framework/project imports below, which may emit them on import.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')
warnings.simplefilter('ignore', FutureWarning)

# --- framework and project imports (deliberately after the silencing above) ---
import tensorrt as trt  # noqa: E402
import common  # noqa: E402
from get_data import ModelData  # noqa: E402

# --- configuration ---
# Verbose logger so every parser/builder step is printed while debugging.
TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)

# One model stem, three artifacts: TensorFlow graph (.pb), its ONNX export,
# and the serialized TensorRT engine.
_MODEL_STEM = './models/resnet50_30w_nchw_no_is_training'
PB_PATH = _MODEL_STEM + '.pb'
ONNX_PATH = _MODEL_STEM + '.onnx'
ENGINE_PATH = _MODEL_STEM + '.engine'
def parse_onnx_file(parser, onnx_path):
    """Feed the serialized ONNX model at *onnx_path* to *parser*.

    ``parser.parse`` reports failure through its return value rather than by
    raising, so ignoring that value lets the script continue with an
    incomplete network. This helper turns a failed parse into an exception
    that carries every error the parser recorded.

    Args:
        parser: A ``trt.OnnxParser`` (or any object exposing ``parse``,
            ``num_errors`` and ``get_error``).
        onnx_path: Path to the ``.onnx`` file to read.

    Raises:
        RuntimeError: If ``parser.parse`` returns a falsy value.
    """
    with open(onnx_path, 'rb') as model_file:
        parsed = parser.parse(model_file.read())
    if not parsed:
        errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
        details = '\n'.join(errors) or '(parser recorded no error details)'
        raise RuntimeError(
            'Failed to parse ONNX model {}:\n{}'.format(onnx_path, details))


if __name__ == '__main__':
    # The network-creation flags form a bit mask indexed by the enum value.
    flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(flag) as network, \
            builder.create_builder_config() as config, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        # NOTE(review): the ONNX parser presumably declares the graph inputs
        # itself, so a manual network.add_input should not be needed here.
        # inp = network.add_input(
        #     name=ModelData.INPUT_NAME,
        #     dtype=ModelData.INPUT_DTYPE,
        #     shape=ModelData.INPUT_SHAPE)

        # Stops the script with the parser's own messages instead of silently
        # continuing on a half-built network.
        parse_onnx_file(parser, ONNX_PATH)

        # NOTE(review) before re-enabling the build steps below:
        # - max_batch_size is believed to be ignored for explicit-batch
        #   networks; the batch range comes from the optimization profile.
        # - when building with build_engine(network, config) the workspace
        #   limit is read from `config`, not `builder`; 100 GiB is likely
        #   more than the GPU has -- confirm the intended value.
        # - build_engine returns None on failure; check it before serialize().
        # builder.max_batch_size = ModelData.MAX_BATCH
        # builder.max_workspace_size = common.GiB(100)
        # profile = builder.create_optimization_profile()
        # profile.set_shape(
        #     ModelData.INPUT_NAME,
        #     ModelData.MIN_INPUT_SHAPE,
        #     ModelData.OPT_INPUT_SHAPE,
        #     ModelData.MAX_INPUT_SHAPE)
        # config.add_optimization_profile(profile)
        # engine = builder.build_engine(network, config)
        # with open(ENGINE_PATH, 'wb') as f: f.write(engine.serialize())