import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
import time
import tensorrt as trt
import common
# You can set the logger severity higher to suppress messages (or lower to display more messages).
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
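# With explicit batch, the batch dimension is part of the network definition;
# the ONNX parser requires it, and it is what allows the dynamic-shape
# optimization profile used in build_engine below.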
class ModelData(object):
    MODEL_FILE = "mobilenetv2-1.0.onnx"
def build_engine(model_file):
    # For more information on TRT basics, refer to the introductory samples.
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
        config = builder.create_builder_config()
        config.max_workspace_size = common.GiB(2)
        # Parse the ONNX network; surface parser errors instead of failing silently.
        with open(model_file, 'rb') as model:
            if not parser.parse(model.read()):
                for i in range(parser.num_errors):
                    print(parser.get_error(i))
                raise RuntimeError('Failed to parse the ONNX file: ' + model_file)
        # With explicit batch, the input shape (batch included) is governed by an
        # optimization profile: min / opt / max shapes for the 'data' binding.
        profile = builder.create_optimization_profile()
        profile.set_shape('data', (1, 3, 224, 224), (1, 3, 256, 256), (1, 3, 512, 512))
        config.add_optimization_profile(profile)
        return builder.build_engine(network, config)
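
# A minimal sketch of caching the built engine on disk so later runs can skip
# the slow build step. This assumes the TensorRT 7.x Python API; the default
# path 'mobilenetv2.trt' is an arbitrary choice, not part of the original script.
def save_engine(engine, path='mobilenetv2.trt'):
    # engine.serialize() returns a buffer of plain bytes.
    with open(path, 'wb') as f:
        f.write(engine.serialize())

def load_engine(path='mobilenetv2.trt'):
    # A Runtime created with the same logger can deserialize the engine.
    with open(path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())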
# Loads a test case into the provided pagelocked_buffer.
# NOTE: this is a stand-in that fills the buffer with random data instead of
# decoding the image at data_paths; see the preprocessing sketch below.
def load_normalized_test_case(data_paths, pagelocked_buffer):
    img = np.random.rand(1, 3, 256, 256).astype(np.float32)
    # Copy into the pagelocked buffer; merely rebinding the parameter name
    # (pagelocked_buffer = img) would leave the actual buffer untouched.
    np.copyto(pagelocked_buffer, img.ravel())
    return img
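
# A minimal sketch of real preprocessing to replace the random data above.
# Assumptions: Pillow is installed, and MobileNetV2 uses the standard ImageNet
# mean/std below; neither is confirmed by the original script.
def load_image(path, size=256):
    from PIL import Image
    img = Image.open(path).convert('RGB').resize((size, size))
    # HWC uint8 -> float32 in [0, 1], normalized per channel, then CHW + batch dim.
    arr = np.asarray(img, dtype=np.float32) / 255.0
    arr = (arr - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
    return np.ascontiguousarray(arr.transpose(2, 0, 1)[np.newaxis], dtype=np.float32)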
def main():
    # Build an engine, allocate buffers and create a stream.
    # For more information on buffer allocation, refer to the introductory samples.
    with build_engine(ModelData.MODEL_FILE) as engine:
        # Always 1 for explicit-batch engines; the profile governs the real batch.
        print(engine.max_batch_size)
        # Note: with dynamic input shapes, common.allocate_buffers must size the
        # host/device buffers for the largest shape in the optimization profile.
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)
        with engine.create_execution_context() as context:
            load_normalized_test_case('test.jpeg', pagelocked_buffer=inputs[0].host)
            # Select a profile and fix the dynamic input shape before inference.
            context.active_optimization_profile = 0
            context.set_binding_shape(0, (1, 3, 256, 256))
            # For more information on performing inference, refer to the introductory samples.
            # do_inference_v2 (execute_async_v2) is the explicit-batch variant of
            # common.do_inference; it returns a list of outputs - we only have one here.
            # Increase the range below to time more than a single run.
            for i in range(1):
                t = time.time()
                [keys] = common.do_inference_v2(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
                print(time.time() - t)
            print(keys.shape)
            print(keys[:10])
if __name__ == '__main__':
    main()