Description
I am trying to convert my custom YOLOv5 PyTorch model and load it from memory. I am also encrypting/decrypting the model for security reasons.
Is there anything wrong? I am out of ideas at this point.
Here are the steps:
1. Load the .pt model.
2. Convert it to ONNX.
3. Convert the ONNX model to a TensorRT engine.
4. Write the serialized engine to a BytesIO object.
5. Encrypt the BytesIO object with a secret key using the Fernet module.
6. Write the encrypted model to disk.
7. Read the encrypted model back and decrypt it with the same secret key.
8. Load the decrypted bytes with TensorRT.
9. Run inference.
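A minimal sketch of steps 4-8 as I understand them (here engine and key stand in for my real engine object and Fernet key); comparing the decrypted bytes with the original serialized engine is how I check that nothing gets corrupted along the way:

# Sketch of steps 4-8; `engine` is an already-built trt.ICudaEngine, `key` a valid Fernet key.
import tensorrt as trt
from cryptography.fernet import Fernet

plain = bytes(engine.serialize())              # step 4: serialized engine bytes
token = Fernet(key).encrypt(plain)             # step 5: encrypt
with open("model.engine.enc", "wb") as fh:     # step 6: write to disk
    fh.write(token)

with open("model.engine.enc", "rb") as fh:     # step 7: read back and decrypt
    restored = Fernet(key).decrypt(fh.read())
assert restored == plain                       # decryption must be byte-identical to the original

runtime = trt.Runtime(trt.Logger(trt.Logger.INFO))        # step 8: load with TensorRT
deserialized = runtime.deserialize_cuda_engine(restored)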
When I encrypt/decrypt and run inference with the .pt model, there are no problems.
When I print the predictions after loading the engine, there are boxes, but after applying NMS nothing is left.
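To see whether the raw engine output is plausible before NMS discards it, I dump a few statistics on the prediction tensor (a debugging sketch; pred is whatever self.model(...) returns right before non_max_suppression in my detect function below):

# Debugging sketch: inspect raw predictions (expected shape (bs, 25200, 5 + nc)) before NMS.
raw = pred[0] if isinstance(pred, (list, tuple)) else pred
print("shape:", raw.shape, "dtype:", raw.dtype)
print("max objectness:", raw[..., 4].max().item())   # must exceed conf_thres for anything to survive NMS
print("max class score:", raw[..., 5:].max().item())
print("xywh range:", raw[..., :4].min().item(), raw[..., :4].max().item())  # should span pixel coordinates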
Here is my test.py:
import sys
sys.path.append("repos/yolov5/")
from importlib import import_module
import torch
from torch import nn
from model_files.object_detection.repos.yolov5.utils.general import check_version
class Detect(nn.Module):
    # YOLOv5 Detect head for detection models
    stride = None  # strides computed during build
    dynamic = False  # force grid reconstruction
    export = False  # export mode

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.empty(0) for _ in range(self.nl)]  # init grid
        self.anchor_grid = [torch.empty(0) for _ in range(self.nl)]  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use inplace ops (e.g. slice assignment)

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                if isinstance(self, Segment):  # (boxes + masks)
                    xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
                    xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
                else:  # Detect (boxes only)
                    xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, self.na * nx * ny, self.no))

        return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

    def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):
        d = self.anchors[i].device
        t = self.anchors[i].dtype
        shape = 1, self.na, ny, nx, 2  # grid shape
        y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
        yv, xv = torch.meshgrid(y, x, indexing='ij') if torch_1_10 else torch.meshgrid(y, x)  # torch>=0.7 compatibility
        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
        anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
        return grid, anchor_grid
def import_via_name(name):
    try:
        module_path, class_name = name.rsplit('.', 1)
        module = import_module(module_path)
        return getattr(module, class_name)
    except (ImportError, AttributeError) as e:
        raise ImportError(name)
def yolov5_inference_test():
    import cv2
    import time

    class DataLoader:
        length = 1

    class_str: str = "model_files.object_detection.models.YoloV5"
    object_detector = import_via_name(class_str)(
        weights="../../model_weights/human.pt",
        half=True,
        device="cuda:0",
        data="model_files/object_detection/repos/yolov5/data/coco128.yaml",
        imgsz=(640, 640),
        dataloader=DataLoader(),
        conf_thres=0.25,
        repo_address="https://github.com/ultralytics/yolov5.git",
        repo_branch="master"
    )
    import torch

    # I am using the following while exporting
    '''object_detector.model.eval()
    for k, m in object_detector.model.named_modules():
        if isinstance(m, Detect):
            m.inplace = inplace
            m.dynamic = dynamic
            m.export = True
    object_detector.export_engine(object_detector.model.half(),
                                  torch.zeros(1, 3, *(640, 640)).to(object_detector.device).half(),
                                  "../../model_weights/human.pt",
                                  True, False, False, workspace=4, verbose=False,
                                  prefix='TensorRT:')'''

    im = cv2.imread("im.jpg")
    while 1:
        t = time.time()
        # will be used in the DataLoader later
        for_batching = object_detector.preprocess(im)
        batch = object_detector.batching_process([for_batching])
        # used in the main process
        res = object_detector.detect(batch)
        # note: this is not a fully accurate way to measure GPU inference time because of CPU/GPU sync
        print(res, time.time() - t)


yolov5_inference_test()
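Regarding the timing note in the loop above: a more faithful way to time the GPU work is to synchronize before reading the clock, roughly:

# Timing sketch: synchronize so the measured interval actually covers the GPU work.
torch.cuda.synchronize()
t = time.time()
res = object_detector.detect(batch)
torch.cuda.synchronize()
print(res, time.time() - t)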
My inference function:
def detect(self, img_batch: np.ndarray) -> Union[np.ndarray, List]:
    print(self.model.fp16, "IS FP16")
    with torch.no_grad():
        with self.dt[0]:
            model_ts = torch.from_numpy(img_batch.copy()).to(self.model.device)
            model_ts = model_ts.half() if self.model.fp16 else model_ts.float()  # uint8 to fp16/32
            model_ts /= 255  # 0 - 255 to 0.0 - 1.0
            if len(model_ts.shape) == 3:
                model_ts = model_ts[None]  # expand for batch dim

        # Inference
        with self.dt[1]:
            pred = self.model(model_ts, augment=False, visualize=False)

        # NMS
        with self.dt[2]:
            pred = self.non_max_suppression(pred, self.conf_thres,
                                            self.iou_thres, self.classes,
                                            self.agnostic_nms, max_det=self.max_det)
    return pred

def preprocess(self, img: Any) -> Any:
    return self.letterbox(img, new_shape=self.imgsz, stride=self.stride, auto=False)[0][:]

@staticmethod
def batching_process(images: Any) -> Any:
    im = np.stack([x for x in images])  # resize
    im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
    im = np.ascontiguousarray(im)  # contiguous
    return im
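As a sanity check I also want to push the same preprocessed batch through both backends and compare the raw outputs before NMS (a sketch; pt_detector and trt_detector are assumed to be two instances of my wrapper, one loaded from human.pt and one from the encrypted engine):

# Sketch: compare raw predictions of the .pt backend and the TensorRT backend on the same input.
with torch.no_grad():
    ts = torch.from_numpy(batch.copy()).to("cuda:0")
    pred_pt = pt_detector.model((ts.half() if pt_detector.model.fp16 else ts.float()) / 255)
    pred_trt = trt_detector.model((ts.half() if trt_detector.model.fp16 else ts.float()) / 255)
pred_pt = pred_pt[0] if isinstance(pred_pt, (list, tuple)) else pred_pt
pred_trt = pred_trt[0] if isinstance(pred_trt, (list, tuple)) else pred_trt
print("shapes:", pred_pt.shape, pred_trt.shape)
print("max abs diff:", (pred_pt.float() - pred_trt.float()).abs().max().item())  # FP16 noise is fine, huge values are not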
My export functions:
def export_onnx(self, model, im, file, opset, dynamic, simplify, prefix='ONNX:'):
    # YOLOv5 ONNX export
    self.check_requirements('onnx>=1.12.0')
    import onnx
    import io

    io_file = io.BytesIO()
    print(f'\n{prefix} starting export with onnx {onnx.__version__}...', model)
    file = file.split(".")[0] + '.nz.onnx'
    print(file)

    output_names = ['output0', 'output1'] if isinstance(model, self.SegmentationModel) else ['output0']
    if dynamic:
        dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}}  # shape(1,3,640,640)
        if isinstance(model, self.SegmentationModel):
            dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
            dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'}  # shape(1,32,160,160)
        elif isinstance(model, self.DetectionModel):
            dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)

    torch.onnx.export(
        model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
        im.cpu() if dynamic else im,
        io_file,
        verbose=True,
        opset_version=12,
        do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
        input_names=['images'],
        output_names=output_names,
        dynamic_axes=dynamic or None)

    # Checks
    model_onnx = onnx.load_model_from_string(io_file.getvalue())  # load onnx model
    print(onnx.checker.check_model(model_onnx), "MODEL ONNX")  # check onnx model

    # Metadata
    d = {'stride': int(model.stride), 'names': model.names}
    for k, v in d.items():
        meta = model_onnx.metadata_props.add()
        meta.key, meta.value = k, str(v)
    io_file2 = io.BytesIO()
    onnx.save(model_onnx, io_file2)
    return file, model_onnx, io_file2
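Before building the engine, the in-memory ONNX bytes can be sanity-checked with onnxruntime, which accepts a serialized model directly (a sketch; io_file2 is the BytesIO returned above):

# Sketch: run the exported ONNX from memory and look at the raw output.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(io_file2.getvalue(), providers=['CPUExecutionProvider'])
in_dtype = np.float16 if sess.get_inputs()[0].type == 'tensor(float16)' else np.float32  # the export above is FP16
sample = np.random.rand(1, 3, 640, 640).astype(in_dtype)
outs = sess.run(None, {'images': sample})
print(outs[0].shape, outs[0][..., 4].max())  # objectness should not be stuck near zero

And here is the engine export: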
def export_engine(self, model, im, file, half, dynamic, simplify, workspace=4, verbose=False,
                  prefix='TensorRT:'):
    # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
    assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
    import tensorrt as trt
    import io
    import onnx

    if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
        grid = model.model[-1].anchor_grid
        model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
        _, _l, io_file2 = self.export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
        model.model[-1].anchor_grid = grid
    else:  # TensorRT >= 8
        _, _l, io_file2 = self.export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
    # onnx = file.split(".")[0] + '.nz.onnx'

    print(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
    f = file.split(".")[0] + 'a.nz.engine'  # TensorRT engine file

    class MyLogger(trt.ILogger):
        def __init__(self):
            trt.ILogger.__init__(self)

    logger = trt.Logger(trt.Logger.INFO)
    # if verbose:
    #     logger.min_severity = trt.Logger.Severity.VERBOSE

    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    config.max_workspace_size = workspace * 1 << 30
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice

    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    print(f, type(_l), network, network.num_layers)
    io_file2.seek(0)
    print(type(io_file2.getbuffer()), 9999)
    print(parser.parse(io_file2.getvalue()), "PARSER")

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    print(f, 310000, inputs, outputs)
    print("num layers:", network.num_layers)
    for inp in inputs:
        print(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
    for out in outputs:
        print(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')

    if dynamic:
        if im.shape[0] <= 1:
            print(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
        profile = builder.create_optimization_profile()
        for inp in inputs:
            profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
        config.add_optimization_profile(profile)

    print(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
    if builder.platform_has_fast_fp16 and half:
        config.set_flag(trt.BuilderFlag.FP16)
        print("FP16 TRUE")
    engine = builder.build_engine(network, config)
    print("!!!!!", engine, f)

    t = io.BytesIO()
    t.write(bytearray(engine.serialize()))
    # t.seek(0)
    crypt.save_encrypt("secret-key", t.getvalue(), f)
    return f, None
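One thing worth adding in export_engine: parser.parse() only returns a bool, so when it fails the actual messages have to be pulled from the parser explicitly (a sketch using the TensorRT OnnxParser error API):

# Sketch: surface ONNX parser errors instead of only printing the boolean return value.
if not parser.parse(io_file2.getvalue()):
    for i in range(parser.num_errors):
        print(parser.get_error(i))
    raise RuntimeError('failed to parse the in-memory ONNX model')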
My crypt functions:
from cryptography.fernet import Fernet
import io


def save_encrypt(key, io_file, saving_path):
    fernet = Fernet(key)
    encrypted = fernet.encrypt(io_file)
    with open(saving_path, 'wb') as encrypted_file:
        encrypted_file.write(encrypted)


def decrypt(key, file_name):
    fernet = Fernet(key)
    with open(file_name, 'rb') as enc_file:
        encrypted = enc_file.read()
    # decrypting the file
    decrypted = fernet.decrypt(encrypted)
    decrypted = io.BytesIO(decrypted)
    return decrypted
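One note on the key: a Fernet key has to be 32 url-safe base64-encoded bytes, so the "secret-key" literal in the snippets is just a placeholder for a key I generated once and reuse, e.g.:

# Sketch: generate a Fernet key once and reuse the same key for encryption and decryption.
from cryptography.fernet import Fernet

key = Fernet.generate_key()          # 44-character url-safe base64 token
with open('model.key', 'wb') as fh:  # store it once, load it for both encrypt and decrypt
    fh.write(key)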
I load the encrypted engine with the following code; I modified it for the decryption process from the original YOLOv5 repo (https://github.com/ultralytics/yolov5/blob/master/models/common.py):
elif engine:  # TensorRT
    LOGGER.info(f'Loading {w} for TensorRT inference...')
    import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
    check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
    if device.type == 'cpu':
        device = torch.device('cuda:0')
    Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
    logger = trt.Logger(trt.Logger.INFO)
    with decrypt("secret-key", w) as f, trt.Runtime(logger) as runtime:
        # f.seek(0)
        model = runtime.deserialize_cuda_engine(f.getbuffer())
    context = model.create_execution_context()
    bindings = OrderedDict()
    output_names = []
    fp16 = False  # default updated below
    dynamic = False
    for i in range(model.num_bindings):
        name = model.get_binding_name(i)
        dtype = trt.nptype(model.get_binding_dtype(i))
        print(dtype, type(dtype), "TYPE")
        if model.binding_is_input(i):
            if -1 in tuple(model.get_binding_shape(i)):  # dynamic
                dynamic = True
                context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
            if dtype == np.float16:
                fp16 = True
                print(fp16, "COMMON FP16")
        else:  # output
            output_names.append(name)
        shape = tuple(context.get_binding_shape(i))
        im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
        bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
    binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
    batch_size = bindings['images'].shape[0]  # if dynamic, this is instead max batch size
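To rule out my wrapper entirely, the deserialized engine can also be executed directly on a dummy input using the same binding API as above (a sketch; model, context and device come from the loading code, and the detection output binding is assumed to be named output0 as in the export):

# Sketch: run the decrypted engine once outside the YOLOv5 wrapper and inspect the output.
bufs, addrs = {}, []
for i in range(model.num_bindings):
    shape = tuple(context.get_binding_shape(i))
    dtype = trt.nptype(model.get_binding_dtype(i))
    t = torch.from_numpy(np.zeros(shape, dtype=dtype)).to(device)  # zeros double as the dummy input
    bufs[model.get_binding_name(i)] = t
    addrs.append(int(t.data_ptr()))
context.execute_v2(addrs)
out = bufs['output0']
print(out.shape, out.float().abs().max().item())  # an all-zero or NaN output points at the engine, not at NMS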
Here is the output from the export:
exporting_out.txt (162.9 KB)
This is the output when I run inference with the engine:
cuda:0 <function select_device at 0x7fabc75fa1f0> <class 'yolov5.models.common.DetectMultiBackend'>
YOLOv5 🚀 v7.0-212-g9974d51 Python-3.8.16 torch-2.0.0+cu117 CUDA:0 (NVIDIA A10G, 22564MiB)
Loading a.nz.engine for TensorRT inference...
[10/03/2023-07:44:46] [TRT] [I] Loaded engine size: 16 MiB
[10/03/2023-07:44:46] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in engine deserialization: CPU +0, GPU +13, now: CPU 0, GPU 13 (MiB)
[10/03/2023-07:44:46] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +16, now: CPU 0, GPU 29 (MiB)
[10/03/2023-07:44:46] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
<class 'numpy.float16'> <class 'type'> TYPE
True COMMON FP16
<class 'numpy.float16'> <class 'type'> TYPE
<class 'numpy.float16'> <class 'type'> TYPE
<class 'numpy.float16'> <class 'type'> TYPE
<class 'numpy.float16'> <class 'type'> TYPE
True IS FP16
[tensor([], device='cuda:0', size=(0, 6))] 0.0050466060638427734
True IS FP16
[tensor([], device='cuda:0', size=(0, 6))] 0.0031447410583496094
True IS FP16
[tensor([], device='cuda:0', size=(0, 6))] 0.002671480178833008
The .pt model inference output:
cuda:0 <function select_device at 0x7fc212550dc0> <class 'yolov5.models.common.DetectMultiBackend'>
YOLOv5 🚀 v7.0-212-g9974d51 Python-3.8.16 torch-2.0.0+cu117 CUDA:0 (NVIDIA A10G, 22564MiB)
Fusing layers...
Model summary: 157 layers, 7015519 parameters, 0 gradients, 15.8 GFLOPs
[tensor([[4.48500e+02, 1.36750e+02, 5.39500e+02, 2.60000e+02, 9.35059e-01, 1.00000e+00],
[2.33500e+02, 1.78625e+02, 3.44000e+02, 3.15000e+02, 9.18945e-01, 1.00000e+00],
[3.41750e+02, 1.36750e+02, 4.35250e+02, 2.58750e+02, 9.16504e-01, 1.00000e+00],
[8.88750e+01, 1.90375e+02, 1.85875e+02, 3.12500e+02, 9.07715e-01, 1.00000e+00],
[2.80500e+02, 1.14500e+02, 3.50000e+02, 1.96500e+02, 9.02832e-01, 1.00000e+00],
[1.74500e+02, 1.32750e+02, 2.41500e+02, 2.27250e+02, 8.99902e-01, 1.00000e+00],
[4.06750e+02, 1.31750e+02, 6.15000e+02, 5.29000e+02, 8.38867e-01, 0.00000e+00],
[1.70250e+02, 2.30125e+02, 2.69750e+02, 3.46500e+02, 7.61230e-01, 1.00000e+00],
[3.33750e+01, 1.81750e+02, 2.44375e+02, 5.30000e+02, 7.58301e-01, 0.00000e+00],
[2.11875e+02, 1.69375e+02, 4.50500e+02, 5.31500e+02, 7.32422e-01, 0.00000e+00],
[1.10250e+02, 1.37750e+02, 3.47500e+02, 5.30000e+02, 4.63623e-01, 0.00000e+00],
[2.96750e+02, 1.26500e+02, 5.46000e+02, 5.33500e+02, 2.76367e-01, 0.00000e+00]], device='cuda:0')] 0.03249621391296387
Environment
TensorRT Version: 8.6.0
GPU Type: NVIDIA A10G (AWS EC2)
Nvidia Driver Version: 530.30.02
CUDA Version: 12.1
CUDNN Version:
Operating System + Version: Ubuntu 22.04.2 LTS (GNU/Linux 6.2.0-1012-aws x86_64)
Python Version (if applicable): 3.8.16
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 2.0.0
Baremetal or Container (if container which image + tag):