import time
import os
import sys
import logging

import torch
import tensorrt as trt
from torch import nn
from pytorch_quantization import calib
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

# Class names of the inplace_abn layers that ONNX cannot export. They are
# matched by name (including base classes) so this module does not need to
# import the inplace_abn package itself.
_ABN_CLASS_NAMES = frozenset({'InPlaceABN', 'InPlaceABNSync'})


def _is_abn(module):
    """Return True if *module* is an InPlaceABN/InPlaceABNSync (or subclass)."""
    return any(cls.__name__ in _ABN_CLASS_NAMES for cls in type(module).__mro__)


def _set_layer(model, name, new_module):
    """Replace the submodule at dotted path *name* inside *model* with *new_module*."""
    if not name:
        raise ValueError('cannot replace the root module in place')
    parent_path, _, child_name = name.rpartition('.')
    parent = model
    if parent_path:
        for part in parent_path.split('.'):
            parent = getattr(parent, part)
    setattr(parent, child_name, new_module)


def _abn_to_batchnorm(abn):
    """Build a BatchNorm2d reproducing the normalization part of an ABN layer."""
    bn = nn.BatchNorm2d(abn.num_features, eps=abn.eps, momentum=abn.momentum,
                        affine=abn.affine,
                        track_running_stats=abn.track_running_stats)
    if abn.affine:
        # The ABN layer's effective scale is |weight| + eps (this is what the
        # original conversion baked in); fold it into the plain BN weight.
        bn.weight.data = torch.abs(abn.weight.data) + abn.eps
        bn.weight.requires_grad = abn.weight.requires_grad
        bn.bias.data = abn.bias.data
        bn.bias.requires_grad = abn.bias.requires_grad
    # Reuse the trained statistics (may be None when stats are not tracked).
    bn.register_buffer('running_mean', abn.running_mean)
    bn.register_buffer('running_var', abn.running_var)
    return bn


def replace_abn_with_batchnorm(model):
    """Replace every InPlaceABN/InPlaceABNSync in *model* with exportable layers.

    Each ABN layer becomes a ``BatchNorm2d`` followed by the matching
    activation: ``identity`` -> BN only, ``leaky_relu`` -> BN + LeakyReLU,
    ``elu`` -> BN + ELU. The activation parameter is taken from the ABN
    layer's ``activation_param`` (0.01 if absent).

    Args:
        model: ``torch.nn.Module`` modified in place.

    Returns:
        The same *model*, for convenience.

    Raises:
        ValueError: if an ABN layer uses an unsupported activation.
    """
    # Snapshot first: replacing children while named_modules() is still
    # iterating would mutate the module tree under the generator.
    targets = [(name, m) for name, m in model.named_modules() if _is_abn(m)]
    for name, abn in targets:
        bn = _abn_to_batchnorm(abn)
        activation = abn.activation
        param = getattr(abn, 'activation_param', 0.01)
        if activation == 'identity':
            replacement = bn
        elif activation == 'leaky_relu':
            replacement = nn.Sequential(
                bn, nn.LeakyReLU(negative_slope=param, inplace=True))
        elif activation == 'elu':
            replacement = nn.Sequential(bn, nn.ELU(alpha=param, inplace=True))
        else:
            raise ValueError(
                'unsupported ABN activation {!r} in layer {!r}'.format(
                    activation, name))
        _set_layer(model, name, replacement)
    return model


def export_ONNX(cfg, onnx_path='quantized_depthnet.onnx'):
    """Export the fine-tuned quantized depthnet to ONNX and sanity-check it.

    Loads ``finetuned_quantized_depthnet.pt`` on CPU, swaps ABN layers for
    BatchNorm2d, and exports with opset 13. It then runs the same random
    input through PyTorch and onnxruntime (on the file just written) and
    prints the difference of the first output.

    Args:
        cfg: configuration mapping; retained for API compatibility, no key is
            read here any more.
        onnx_path: destination of the exported ONNX file; the verification
            step loads this same file.
    """
    previous_fake_quant = quant_nn.TensorQuantizer.use_fb_fake_quant
    # Export fake-quant nodes via torch's fake_quantize ops (ONNX Q/DQ pairs);
    # also keep it on for the reference forward so both paths match.
    quant_nn.TensorQuantizer.use_fb_fake_quant = True
    try:
        opset_version = 13
        # NOTE(review): layout (1, 3456, 480, 3) kept from the original —
        # confirm it matches what the network expects.
        input_size_depthnet = [3456, 480, 3]
        depthnet_input_tensor = torch.randn(1, *input_size_depthnet,
                                            device='cpu')

        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints.
        depthnet = torch.load('finetuned_quantized_depthnet.pt',
                              map_location=torch.device('cpu'))
        print('Exporting to ONNX, model name {}'.format(
            depthnet.__class__.__name__))

        # As ONNX does not support ABN blocks, replace them with BatchNorm2d.
        replace_abn_with_batchnorm(depthnet)
        depthnet.eval()

        # NOTE: `enable_onnx_checker` was removed in newer torch releases;
        # drop it if torch.onnx.export rejects the keyword.
        torch.onnx.export(depthnet, depthnet_input_tensor, onnx_path,
                          input_names=['im_data'],
                          output_names=['depth', 'error', 'rgb_feats'],
                          verbose=False, opset_version=opset_version,
                          enable_onnx_checker=False)

        import onnxruntime as ort

        with torch.no_grad():
            torch_outs = depthnet(depthnet_input_tensor)
        if isinstance(torch_outs, torch.Tensor):
            torch_outs = (torch_outs,)
        torch_out = torch_outs[0].detach().cpu().numpy()

        # Verify the file we just exported, not some other configured path.
        ort_sess = ort.InferenceSession(onnx_path,
                                        providers=['CPUExecutionProvider'])
        input_name = ort_sess.get_inputs()[0].name
        ort_out = ort_sess.run(
            None, {input_name: depthnet_input_tensor.numpy()})[0]

        diff = torch_out - ort_out
        print('MODEL MAX DIFF', diff.max())
        print('MODEL MIN DIFF', diff.min())
    finally:
        # Do not leak the global fake-quant mode to later callers.
        quant_nn.TensorQuantizer.use_fb_fake_quant = previous_fake_quant