from transformers import BertTokenizerFast
import os
import json
import time
import threading
from typing import Tuple, List
import numpy as np
try:
import pycuda.driver as cuda
cuda.init()
import tensorrt as trt
TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
    TRT_AVAILABLE = True
except ImportError:
    TRT_AVAILABLE = False
try:
import onnxruntime
    ONNXRUNTIME_AVAILABLE = True
except ImportError:
    ONNXRUNTIME_AVAILABLE = False
tokenizer = BertTokenizerFast.from_pretrained("hfl/chinese-macbert-base")
def preprocess_data(text="this is a sad thing", is_trt=False):
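    # Tokenize five copies of the same sentence, padded/truncated to max_length=128,
    # so both backends receive identical (5, 128) int arrays for input_ids,
    # attention_mask and token_type_ids.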
texts = [text for _ in range(5)]
    context = tokenizer(texts, padding="max_length", return_tensors='pt',
                        max_length=128, truncation=True, return_offsets_mapping=True)
input_ids = context['input_ids'].detach().cpu().numpy()
attention_mask = context['attention_mask'].detach().cpu().numpy()
token_type_ids = context['token_type_ids'].detach().cpu().numpy()
if is_trt:
return [input_ids, attention_mask, token_type_ids]
else:
return {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids}
class TrtModel:
def __init__(self, model_name="detector_corrector", model_dir=".", cached_engine=True, max_batch_size=5) -> None:
self.cfx = cuda.Device(0).make_context()
self.model_dir = model_dir
self.model_name = model_name
self.max_batch_size = max_batch_size
        self.cached_engine = cached_engine
self.engine = self.load_model(
os.path.join(model_dir, model_name + ".onnx"))
self.input_binding_idxs, self.output_binding_idxs = self.get_binding_idxs()
self.input_names = [self.engine.get_binding_name(
binding_idx) for binding_idx in self.input_binding_idxs]
self.output_names = [self.engine.get_binding_name(
binding_idx) for binding_idx in self.output_binding_idxs]
def __del__(self):
self.cfx.detach()
def load_model(self, model_path):
return self.load_engine(model_path)
def get_context(self):
return self.engine.create_execution_context()
def get_stream(self):
return cuda.Stream()
def predict(self, host_inputs):
self.cfx.push()
context = self.get_context()
stream = self.get_stream()
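        # Allocate a device buffer for each host input and copy host -> device asynchronously.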
device_inputs = [cuda.mem_alloc(h_input.nbytes)
for h_input in host_inputs]
for h_input, d_input in zip(host_inputs, device_inputs):
cuda.memcpy_htod_async(d_input, h_input, stream)
host_outputs, device_outputs = self.gen_output_buffer(
host_inputs, context)
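        # execute_async_v2 expects one device pointer per binding index; concatenating
        # inputs then outputs assumes all input bindings precede all output bindings,
        # which is the usual layout for engines parsed from ONNX.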
bindings = device_inputs + device_outputs
exe_res = context.execute_async_v2(
bindings=bindings, stream_handle=stream.handle)
if not exe_res:
print(f"{self.__class__.__name__} execute_async_v2 error")
for h_output, d_output in zip(host_outputs, device_outputs):
cuda.memcpy_dtoh_async(h_output, d_output, stream)
stream.synchronize()
for b in bindings:
b.free()
self.cfx.pop()
return host_outputs
def gen_output_buffer(self, host_inputs: List[np.ndarray], context):
assert context.all_binding_shapes_specified
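        # Note: if the ONNX model was exported with dynamic axes, the input binding
        # shapes would have to be set via context.set_binding_shape(...) before this
        # assert; here a statically shaped (batch 5, seq len 128) model is assumed.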
host_outputs = []
device_outputs = []
for binding_index in self.output_binding_idxs:
output_shape = context.get_binding_shape(binding_index)
# Allocate buffers to hold output results after copying back to host
buffer = np.empty(output_shape, dtype=np.float32)
host_outputs.append(buffer)
# Allocate output buffers on device
device_outputs.append(cuda.mem_alloc(buffer.nbytes))
return host_outputs, device_outputs
def get_binding_idxs(self):
# Separate input and output binding indices for convenience
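        # num_bindings / binding_is_input belong to the older binding-index API,
        # which newer TensorRT releases deprecate in favour of the tensor-name
        # interface (engine.get_tensor_name / engine.get_tensor_mode).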
input_binding_idxs = []
output_binding_idxs = []
for binding_index in range(0, self.engine.num_bindings):
if self.engine.binding_is_input(binding_index):
input_binding_idxs.append(binding_index)
else:
output_binding_idxs.append(binding_index)
return input_binding_idxs, output_binding_idxs
def load_engine(self, onnx_file_path):
runtime = trt.Runtime(TRT_LOGGER)
cached_engine_path = os.path.join(
self.model_dir, self.model_name + ".engine")
        if self.cached_engine and os.path.exists(cached_engine_path):
with open(cached_engine_path, "rb") as f:
serialized_engine = f.read()
engine = runtime.deserialize_cuda_engine(serialized_engine)
print(f"load engine from cache: {cached_engine_path} sucessfully")
return engine
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
builder.max_batch_size = self.max_batch_size
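            # Note: with an explicit-batch network, builder.max_batch_size has no
            # effect; the batch size comes from the network's input shapes.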
with open(onnx_file_path, 'rb') as model:
if not parser.parse(model.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30
# config.set_flag(trt.BuilderFlag.FP16)
serialized_engine = builder.build_serialized_network(
network, config=config)
engine = runtime.deserialize_cuda_engine(serialized_engine)
print("build engine from {} sucessfully".format(onnx_file_path))
            if self.cached_engine:
with open(cached_engine_path, "wb") as f:
f.write(serialized_engine)
print(f"cached engine to: {cached_engine_path}")
return engine
class OnnxModel:
def __init__(self, model_name="detector_corrector", model_dir="."):
if not model_name.endswith(".onnx"):
model_name = model_name + ".onnx"
model_path = os.path.join(model_dir, model_name)
print(f"onnx model path is {model_path}")
self.ort_session = self.load_model(model_path)
def load_model(self, model_path):
        providers = ['CUDAExecutionProvider']  # run ONNX Runtime on the GPU via the CUDA execution provider
# sess_options = onnxruntime.SessionOptions()
# sess_options.intra_op_num_threads = 10
# sess_options.inter_op_num_threads = 10
print(f"onnxruntime get device {onnxruntime.get_device()} available providers {onnxruntime.get_available_providers()}")
ort_session = onnxruntime.InferenceSession(
model_path, providers=providers)
print(f"onnxruntime session providers {ort_session.get_providers()}")
return ort_session
def predict(self, inputs):
ort_outs = self.ort_session.run(None, inputs)
return ort_outs
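# Quick parity check: run the same batch through TensorRT and ONNX Runtime
# and compare the two pairs of logits element-wise.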
trt_model = TrtModel()
onnx_model = OnnxModel()
trt_output = trt_model.predict(preprocess_data(is_trt=True))
onnx_output = onnx_model.predict(preprocess_data(is_trt=False))
trt_detector_logits, trt_corrector_logits = trt_output
onnx_detector_logits, onnx_corrector_logits = onnx_output
assert np.allclose(trt_detector_logits, onnx_detector_logits)
assert np.allclose(trt_corrector_logits, onnx_corrector_logits)
If you run this code, the np.allclose comparisons at the end always come out False, i.e. the TensorRT and ONNX Runtime outputs do not match.
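For debugging, printing the actual deviation is more informative than the bare asserts. A minimal sketch (assuming the TensorRT and ONNX Runtime outputs already have matching shapes and dtypes, and using an illustrative tolerance of atol=1e-3) would be:

for name, a, b in [("detector", trt_detector_logits, onnx_detector_logits),
                   ("corrector", trt_corrector_logits, onnx_corrector_logits)]:
    # Report how far apart the two backends actually are before deciding on a tolerance.
    max_diff = np.max(np.abs(a - b))
    print(f"{name} logits: max abs diff {max_diff:.6f}, "
          f"allclose(atol=1e-3)={np.allclose(a, b, atol=1e-3)}")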