The ONNX model can be downloaded here: rfdn_asx4_nf64nm2inc3_calibrated_op10.onnx - Google Drive
This is the script I use, modified from the tutorial of the pytorch-quantization tool:
Quantization
def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect calibration statistics.

    Every ``quant_nn.TensorQuantizer`` that has a calibrator is switched to
    calibration mode (quantization disabled) for the forward passes and
    switched back to quantization mode afterwards. Quantizers without a
    calibrator are disabled during collection and re-enabled at the end.

    Args:
        model: Network whose ``named_modules()`` yields the quantizers.
        data_loader: Iterable yielding ``(lr, hr, filename)`` batches; only
            ``lr`` is moved to the GPU and fed to the model.
        num_batches: Number of batches to run through the model.
    """
    from itertools import islice

    # Enable calibrators
    for _name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    # islice stops after exactly ``num_batches`` batches; checking the index
    # after the forward pass would feed one batch more than the progress bar
    # total announces.
    batches = islice(data_loader, num_batches)
    for lr, _hr, _filename in tqdm(batches, total=num_batches):
        model(lr.cuda())

    # Disable calibrators
    for _name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()
def compute_amax(model, **kwargs):
    """Load the collected calibration result into every quantizer.

    Args:
        model: Network whose ``named_modules()`` yields the quantizers.
        **kwargs: Options forwarded to ``load_calib_amax`` for every
            calibrator that is not a ``calib.MaxCalibrator``.
    """
    for name, module in model.named_modules():
        if not isinstance(module, quant_nn.TensorQuantizer):
            continue

        calibrator = module._calibrator
        if calibrator is not None:
            # A MaxCalibrator is loaded without options; any other
            # calibrator receives the caller's keyword arguments.
            is_max = isinstance(calibrator, calib.MaxCalibrator)
            options = {} if is_max else kwargs
            module.load_calib_amax(**options)

        # NOTE(review): printed for every TensorQuantizer, following the
        # upstream pytorch-quantization example layout; confirm this is the
        # intended nesting.
        print(F"{name:40}: {module}")

    model.cuda()
# Configure quantization defaults before the model is constructed.
quant_modules.initialize()
quant_desc_input = QuantDescriptor(calib_method='histogram')
for quant_layer in (quant_nn.QuantConv2d, quant_nn.QuantLinear):
    quant_layer.set_default_quant_desc_input(quant_desc_input)

# Build the network, load the pretrained weights and move it to the GPU.
pretrain_model = 'checkpoints/rfdn_asx4_nf64nm2inc3.pt'
model = RFDN_ASX4_nf64m2()
state_dict = torch.load(pretrain_model)
model.load_state_dict(state_dict, strict=True)
model.cuda()

loader = Data(args)
test_loader = loader.loader_test
train_loader = loader.loader_train

# quantize: collect statistics on training batches, then load the amax values
with torch.no_grad():
    collect_stats(model, train_loader, num_batches=10)
    compute_amax(model, method="percentile", percentile=99.99)

# Store the calibrated weights next to the original checkpoint.
calibrated_path = pretrain_model.replace('.pt', '_calibrated.pt')
torch.save(model.state_dict(), calibrated_path)

# export onnx
quant_nn.TensorQuantizer.use_fb_fake_quant = True
dummy_input = torch.zeros(1, 3, 2160, 3840, requires_grad=False).cuda()
torch.set_grad_enabled(False)

# enable_onnx_checker needs to be disabled.
export_options = dict(
    verbose=True,
    input_names=['input'],
    output_names=['output'],
    opset_version=10,
    enable_onnx_checker=False,
)
torch.onnx.export(model, dummy_input, 'rfdn_calibrated_op10.onnx', **export_options)
Build TensorRT engine
def get_engine(onnx_file_path, engine_file_path=""):
    """Load a serialized engine if available, otherwise build and save one.

    Args:
        onnx_file_path: Path of the ONNX model to parse when an engine has to
            be built.
        engine_file_path: Path of the serialized engine. If the file exists it
            is deserialized and returned; otherwise a freshly built engine is
            written there. With the default empty string nothing is read or
            written and the engine is only built in memory.

    Returns:
        The TensorRT engine, or ``None`` if parsing the ONNX file or building
        the engine failed. The process exits if the ONNX file does not exist.
    """
    def build_engine():
        """Takes an ONNX file and creates a TensorRT engine to run inference with"""
        network_creation_flag = int('11', 2)  # EXPLICIT_BATCH | EXPLICIT_PRECISION
        with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
                network_creation_flag) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
            builder.max_workspace_size = common.GiB(2)
            builder.max_batch_size = 1
            builder.int8_mode = True
            builder.int8_calibrator = None
            # Parse model file
            if not os.path.exists(onnx_file_path):
                print(
                    'ONNX file {} not found, please run yolov3_to_onnx.py first to generate it.'.format(onnx_file_path))
                exit(0)
            print('Loading ONNX file from path {}...'.format(onnx_file_path))
            with open(onnx_file_path, 'rb') as model:
                print('Beginning ONNX file parsing')
                if not parser.parse(model.read()):
                    print('ERROR: Failed to parse the ONNX file.')
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))
                    return None
            print('Completed parsing of ONNX file')
            print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
            engine = builder.build_cuda_engine(network)
            if engine is None:
                # The builder reports failure by returning None; serializing
                # it would only raise an unrelated AttributeError.
                print('ERROR: Failed to build the engine.')
                return None
            print("Completed creating Engine")
            # Only cache the engine when a target path was given; the default
            # empty path cannot be opened for writing.
            if engine_file_path:
                with open(engine_file_path, "wb") as f:
                    f.write(engine.serialize())
            return engine

    if engine_file_path and os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building an engine.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    else:
        return build_engine()
# Input resolution (height, width, channels) and the output upscaling factor.
H, W, C = 2160, 3840, 3
scale = 1

input_shape = (1, 3, H, W)
output_shape = (1, 3, scale * H, scale * W)

with get_engine('rfdn_calibrated_op10.onnx', 'rfdn_calibrated_op10.engine') as engine:
    with engine.create_execution_context() as context:
        buffers = common.allocate_buffers(engine,
                                          input_shapes=input_shape,
                                          output_shapes=output_shape)
        inputs, outputs, bindings, stream, _ = buffers
        # Use the prepared image as the host-side buffer of the first input.
        inputs[0].host = input_image
        trt_outputs = common.do_inference_v2(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
I ran the snippet you provided; it gives the following error:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.6/dist-packages/onnx/checker.py", line 91, in check_model
C.check_model(model.SerializeToString())
onnx.onnx_cpp2py_export.checker.ValidationError: Unrecognized attribute: axis for operator QuantizeLinear
==> Context: Bad node spec: input: "fea_conv.0.weight" input: "147" input: "643" output: "150" name: "QuantizeLinear_7" op_type: "QuantizeLinear" attribute { name: "axis" i: 0 type: INT }
Besides, I tested the ResNet-50 model used in the pytorch-quantization tutorial; it gives a similar error:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.6/dist-packages/onnx/checker.py", line 91, in check_model
C.check_model(model.SerializeToString())
onnx.onnx_cpp2py_export.checker.ValidationError: Unrecognized attribute: axis for operator QuantizeLinear
==> Context: Bad node spec: input: "conv1.weight" input: "435" input: "1198" output: "438" name: "QuantizeLinear_7" op_type: "QuantizeLinear" attribute { name: "axis" i: 0 type: INT }