Triton server for SQuAD model on P100 with TensorRT 6.0

I am using this example:

https://developer.nvidia.com/blog/nlu-with-tensorrt-bert/

The only change is that I am using TensorRT 6.0 instead of 5.1.

There are a couple of hiccups in the code. For instance, the filename to download from NGC is wrong, but those are minor issues.

In /workspace/TensorRT/demo/BERT I changed CMakeLists.txt to use compute capability 6.0:

set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}  \
--expt-relaxed-constexpr \
--expt-extended-lambda \
-gencode arch=compute_60,code=sm_60 \
-Wno-deprecated-declarations")

In the original code it looks like this:

set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}  \
--expt-relaxed-constexpr \
--expt-extended-lambda \
-gencode arch=compute_70,code=sm_70 \
-gencode arch=compute_75,code=sm_75 \
-Wno-deprecated-declarations")
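
As a quick sanity check that compute 6.0 is the right target for the card, the snippet below prints the compute capability of GPU 0. This is just a minimal sketch and assumes pycuda is available in the container (it is not part of the demo itself):

import pycuda.driver as cuda

cuda.init()
major, minor = cuda.Device(0).compute_capability()
# A Tesla P100 should report 6.0; a V100 reports 7.0
print("compute capability: %d.%d" % (major, minor))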

After I do that, I can create the engine and run inference successfully. Here is the log:

root@102a004a0db6:/workspace/TensorRT/demo/BERT# python python/bert_builder.py -m /workspace/models/fine-tuned/bert_tf_v2_base_fp16_384_v2/model.ckpt-8144 -o bert_base_384.engine -b 1 -s 384 -c /workspace/models/fine-tuned/bert_tf_v2_base_fp16_384_v2
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:529: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:530: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:535: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
[TensorRT] INFO: Using configuration file: /workspace/models/fine-tuned/bert_tf_v2_base_fp16_384_v2/bert_config.json
[TensorRT] INFO: Found 202 entries in weight map
[TensorRT] INFO: Detected 3 inputs and 1 output network tensors.
[TensorRT] INFO: Detected 3 inputs and 1 output network tensors.
[TensorRT] INFO: Detected 3 inputs and 1 output network tensors.
[TensorRT] INFO: Saving Engine to bert_base_384.engine
[TensorRT] INFO: Done.
root@102a004a0db6:/workspace/TensorRT/demo/BERT# python python/bert_inference.py -e bert_base_384.engine -p "TensorRT is a high performance deep learning inference platform that delivers low latency and high throughput for apps such as recommenders, speech and image/video on NVIDIA GPUs. It includes parsers to import models, and plugins to support novel ops and layers before applying optimizations for inference. Today NVIDIA is open sourcing parsers and plugins in TensorRT so that the deep learning community can customize and extend these components to take advantage of powerful TensorRT optimizations for your apps." -q "What is TensorRT?" -v /workspace/models/fine-tuned/bert_tf_v2_base_fp16_384_v2/vocab.txt -b 1
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:529: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:530: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/dtypes.py:535: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])

Passage: TensorRT is a high performance deep learning inference platform that delivers low latency and high throughput for apps such as recommenders, speech and image/video on NVIDIA GPUs. It includes parsers to import models, and plugins to support novel ops and layers before applying optimizations for inference. Today NVIDIA is open sourcing parsers and plugins in TensorRT so that the deep learning community can customize and extend these components to take advantage of powerful TensorRT optimizations for your apps.

Question: What is TensorRT?

Running Inference...
------------------------
Running inference in 85.506 Sentences/Sec
------------------------
Processing output 0 in batch
Answer: 'a high performance deep learning inference platform'
With probability: 63.446

Now the issue is that when I try to load this model in the Triton server, I get a version error.

I create the following config.pbtxt file (I am using batch size 1 for testing):

name: "squad_trt_1"
platform: "tensorrt_plan"
#max_batch_size: 1
input [
{
name: "segment_ids"
data_type: TYPE_INT32
dims: [384,1]
},
{
name: "input_ids"
data_type: TYPE_INT32
dims: [384,1]
},
{
name: "input_mask"
data_type: TYPE_INT32
dims: [384,1]
}
]
output [
{
name: "cls_squad_logits"
data_type: TYPE_FP32
dims: [-1,-1,2,1,1]
}
]
instance_group {
kind: KIND_GPU
count: 1
}
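
To cross-check the tensor names and dims in this config against the actual engine, a small sketch like the one below can list the plan's bindings. This is my own illustration, not part of the demo: it assumes the BERT demo's plugin libraries (built under the demo's build directory) can be loaded with ctypes before the engine is deserialized, and that the paths are adjusted to your layout:

import ctypes
import tensorrt as trt

# The custom BERT plugins must be registered before the plan can be deserialized.
# These paths are an assumption; point them at the demo's build output.
ctypes.CDLL("/workspace/TensorRT/demo/BERT/build/libcommon.so")
ctypes.CDLL("/workspace/TensorRT/demo/BERT/build/libbert_plugins.so")

logger = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(logger, "")

with open("bert_base_384.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
    for i in range(engine.num_bindings):
        kind = "input" if engine.binding_is_input(i) else "output"
        # the names, shapes, and dtypes printed here are what config.pbtxt must match
        print(kind, engine.get_binding_name(i), engine.get_binding_shape(i), engine.get_binding_dtype(i))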

This is the layout of my models directory:

models/
└── squad_trt_1
    ├── 1
    │   └── model.plan
    └── config.pbtxt

Finally, if I run the Triton server command, I get this:

root@nvidia-ngc-image-3-vm:~/src/TensorRT/demo/BERT/models# docker run --gpus=all --rm --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -p8000:8000 -p8001:8001 -p8002:8002 -v /home/fciannel/src/TensorRT/demo/BERT/models:/models nvcr.io/nvidia/tritonserver:20.03.1-py3 tritonserver --model-repository=/models

=============================
== Triton Inference Server ==
=============================

NVIDIA Release 20.03.1 (build 12830698)

Copyright (c) 2018-2020, NVIDIA CORPORATION.  All rights reserved.

Various files include modifications (c) NVIDIA CORPORATION.  All rights reserved.
NVIDIA modifications are covered by the license terms that apply to the underlying
project or file.

2020-06-23 22:27:35.117066: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.2
I0623 22:27:38.722028 1 metrics.cc:164] found 1 GPUs supporting NVML metrics
I0623 22:27:38.727533 1 metrics.cc:173]   GPU 0: Tesla P100-PCIE-16GB
I0623 22:27:38.728670 1 server.cc:130] Initializing Triton Inference Server
I0623 22:27:39.246667 1 server_status.cc:55] New status tracking for model 'squad_trt_1'
I0623 22:27:39.246778 1 model_repository_manager.cc:723] loading: squad_trt_1:1
W0623 22:27:40.732369 1 metrics.cc:276] failed to get energy consumption for GPU 0, NVML_ERROR 3
W0623 22:27:42.735866 1 metrics.cc:276] failed to get energy consumption for GPU 0, NVML_ERROR 3
W0623 22:27:44.739293 1 metrics.cc:276] failed to get energy consumption for GPU 0, NVML_ERROR 3
E0623 22:28:42.251715 1 logging.cc:43] ../rtSafe/coreReadArchive.cpp (38) - Serialization Error in verifyHeader: 0 (Version tag does not match)
E0623 22:28:42.315513 1 logging.cc:43] INVALID_STATE: std::exception
E0623 22:28:42.315572 1 logging.cc:43] INVALID_CONFIG: Deserialize the cuda engine failed.
E0623 22:28:42.332637 1 model_repository_manager.cc:891] failed to load 'squad_trt_1' version 1: Internal: unable to create TensorRT engine
error: creating server: Internal - failed to load all models

I understand from the release notes that Triton 20.03.1 supports compute capability 6.0 and higher, so why is the model failing to load?
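
For what it's worth, the "Version tag does not match" error from coreReadArchive typically means the TensorRT version that serialized the plan is not the one linked into the server (the 20.03.1 Triton container ships a newer TensorRT, 7.x at that point, than the 6.0 used to build the engine), since a serialized .plan only loads with the same TensorRT version that produced it. A quick, hedged way to compare is to print the TensorRT version inside both containers:

# run once in the build container and once in the server container;
# the two versions must match for the .plan to deserialize
import tensorrt as trt
print(trt.__version__)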

If I run a previous version of the inference server, I get the following:

root@nvidia-ngc-image-3-vm:~/src/TensorRT/demo/BERT/models# docker run --gpus=all --rm --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -p8000:8000 -p8001:8001 -p8002:8002 -v /home/fciannel/src/TensorRT/demo/BERT/models:/models nvcr.io/nvidia/tensorrtserver:19.02-py3 trtserver --model-store=/models

===============================
== TensorRT Inference Server ==
===============================

NVIDIA Release 19.02 (build 5627847)

Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
Copyright 2018 The TensorFlow Authors.  All rights reserved.
Copyright 2018 The TensorFlow Serving Authors.  All rights reserved.
Copyright (c) 2016-present, Facebook Inc. All rights reserved.

Various files include modifications (c) NVIDIA CORPORATION.  All rights reserved.
NVIDIA modifications are covered by the license terms that apply to the underlying
project or file.
Failed to detect NVIDIA driver version.

I0623 22:41:56.533124 1 server.cc:761] Initializing TensorRT Inference Server
I0623 22:41:56.533238 1 server.cc:811] Reporting prometheus metrics on port 8002
I0623 22:41:56.547854 1 metrics.cc:150] found 1 GPUs supporting NVML metrics
I0623 22:41:56.553304 1 metrics.cc:159]   GPU 0: Tesla P100-PCIE-16GB
I0623 22:41:56.554668 1 server.cc:1195] Starting server 'inference:0' listening on
I0623 22:41:56.554694 1 server.cc:1199]  localhost:8001 for gRPC requests
I0623 22:41:56.554841 1 server.cc:1103] Building nvrpc server
I0623 22:41:56.554875 1 server.cc:1109] Register TensorRT GRPCService
I0623 22:41:56.555796 1 server.cc:1112] Register Infer RPC
I0623 22:41:56.555824 1 server.cc:1116] Register Status RPC
I0623 22:41:56.555833 1 server.cc:1120] Register Profile RPC
I0623 22:41:56.555841 1 server.cc:1124] Register Health RPC
I0623 22:41:56.555849 1 server.cc:1128] Register Executor
I0623 22:41:56.563500 1 server.cc:1209]  localhost:8000 for HTTP requests
[warn] getaddrinfo: address family for nodename not supported
[evhttp_server.cc : 237] RAW: Entering the event loop ...
I0623 22:41:56.602221 1 server_status.cc:105] New status tracking for model 'squad_trt_1'
I0623 22:41:56.628283 1 server_core.cc:465] Adding/updating models.
I0623 22:41:56.628312 1 server_core.cc:562]  (Re-)adding model: squad_trt_1
I0623 22:41:56.728639 1 basic_manager.cc:739] Successfully reserved resources to load servable {name: squad_trt_1 version: 1}
I0623 22:41:56.728672 1 loader_harness.cc:66] Approving load for servable version {name: squad_trt_1 version: 1}
I0623 22:41:56.728695 1 loader_harness.cc:74] Loading servable version {name: squad_trt_1 version: 1}
I0623 22:41:57.629418 1 plan_bundle.cc:202] Creating instance squad_trt_1_0_0_gpu0 on GPU 0 (6.0) using model.plan
E0623 22:41:58.556772 1 metrics.cc:239] failed to get energy consumption for GPU 0, NVML_ERROR 3
E0623 22:42:00.560240 1 metrics.cc:239] failed to get energy consumption for GPU 0, NVML_ERROR 3
E0623 22:42:02.563626 1 metrics.cc:239] failed to get energy consumption for GPU 0, NVML_ERROR 3
E0623 22:42:04.566993 1 metrics.cc:239] failed to get energy consumption for GPU 0, NVML_ERROR 3
E0623 22:42:06.570363 1 metrics.cc:239] failed to get energy consumption for GPU 0, NVML_ERROR 3
E0623 22:42:08.579420 1 metrics.cc:239] failed to get energy consumption for GPU 0, NVML_ERROR 3
E0623 22:42:10.582825 1 metrics.cc:239] failed to get energy consumption for GPU 0, NVML_ERROR 3
E0623 22:42:12.586202 1 metrics.cc:239] failed to get energy consumption for GPU 0, NVML_ERROR 3
E0623 22:42:14.595283 1 metrics.cc:239] failed to get energy consumption for GPU 0, NVML_ERROR 3
E0623 22:42:16.598764 1 metrics.cc:239] failed to get energy consumption for GPU 0, NVML_ERROR 3
E0623 22:42:18.602144 1 metrics.cc:239] failed to get energy consumption for GPU 0, NVML_ERROR 3
E0623 22:42:20.611196 1 metrics.cc:239] failed to get energy consumption for GPU 0, NVML_ERROR 3
E0623 22:42:22.614621 1 metrics.cc:239] failed to get energy consumption for GPU 0, NVML_ERROR 3
trtserver: engine.cpp:868: bool nvinfer1::rt::Engine::deserialize(const void*, std::size_t, nvinfer1::IGpuAllocator&, nvinfer1::IPluginFactory*): Assertion `size >= bsize && "Mismatch between allocated memory size and expected size of serialized engine."' failed.

What's the best way for me to have this running on a P100 GPU? Everything is fine if I run on a V100 with TensorRT 7.0.

UPDATE

I managed to get this working by setting the CUDA compute capability to 60 (somehow the Docker build had not worked before, so I had another run with --no-cache) and by loading the plugins that are created at build time. This is the command I used:

docker run --gpus=all --rm --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -p8000:8000 -p8001:8001 -p8002:8002 -v /home/fciannel/src/TensorRT/demo/BERT/models:/models -eLD_PRELOAD=/models/libbert_plugins.so:/models/libcommon.so nvcr.io/nvidia/tensorrtserver:19.09-py3 trtserver --model-store=/models

where libbert_plugins.so and libcommon.so are the libraries created inside the build directory of the container.
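
As a standalone check (outside the server) that these two libraries are enough to deserialize the plan, something like the sketch below works; the paths are assumptions based on my setup:

import ctypes
import tensorrt as trt

# mirror what LD_PRELOAD does for the server: register the demo's custom plugins first
ctypes.CDLL("/models/libcommon.so")
ctypes.CDLL("/models/libbert_plugins.so")

logger = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(logger, "")

with open("/models/squad_trt_1/1/model.plan", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
    print("deserialized OK" if engine is not None else "deserialization failed")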

Also, the config.pbtxt file now looks like this:

name: "squad_trt_1"
platform: "tensorrt_plan"
#max_batch_size: 1
input [
{
name: "segment_ids"
data_type: TYPE_INT32
dims: [1,384]
},
{
name: "input_ids"
data_type: TYPE_INT32
dims: [1,384]
},
{
name: "input_mask"
data_type: TYPE_INT32
dims: [1,384]
}
]
output [
{
name: "cls_dense"
data_type: TYPE_FP32
dims: [1,384,2,1,1]
}
]
instance_group {
kind: KIND_GPU
count: 1
}
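
With the server up, a quick readiness check against the 19.09 server's v1 HTTP status API confirms the model reports READY; localhost:8000 is assumed from the port mapping above:

# hedged sanity check against the v1 status endpoint of tensorrtserver 19.09
import urllib.request

with urllib.request.urlopen("http://localhost:8000/api/status/squad_trt_1") as resp:
    print(resp.read().decode())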