TensorFlow model acceleration on AGX

We are getting only about 2 fps when running inference on 1920x1080 input images with TensorFlow 1.15.

We need to reach at least 6 fps.

What direction would you recommend?
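To make the target concrete, here is a minimal sketch of the frame-time arithmetic using only the 2 fps and 6 fps figures from this post. The function names are hypothetical helpers. The budget is end to end, so pre-processing and post-processing have to fit inside it as well.

```python
def frame_budget_ms(fps: float) -> float:
    """Time available per frame (ms) to sustain `fps` frames per second."""
    if fps <= 0:
        raise ValueError("fps must be positive")
    return 1000.0 / fps


def required_speedup(current_fps: float, target_fps: float) -> float:
    """Factor by which the per-frame time must shrink to reach target_fps."""
    return target_fps / current_fps


current, target = 2.0, 6.0  # figures from the question
print(f"current: {frame_budget_ms(current):.1f} ms/frame")  # 500.0
print(f"target:  {frame_budget_ms(target):.1f} ms/frame")   # 166.7
print(f"speedup needed: {required_speedup(current, target):.1f}x")  # 3.0
```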

I have seen several discussions about this, including suggestions to first convert the model to ONNX with

https://github.com/onnx/tensorflow-onnx

and then measure performance with

/usr/src/tensorrt/bin/trtexec --onnx=[your/file]

Could you advise which methods to try? Thanks.


You could try TFLite to speed things up.

However, I do not know if it will be 3 times faster.
Since I don’t know what model it is, I can’t say if it can be converted to TensorRT.
TF-TRT might be able to implement it, but whether it is faster or not depends on where the bottlenecks are.
In my experience, most of the bottlenecks are in pre-processing and post-processing.
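Here is a minimal sketch for checking where the time actually goes. It assumes the pipeline can be split into pre-processing, inference and post-processing callables. `profile_stages` and the three stage functions are hypothetical placeholders; the `time.sleep` stands in for a model call. The first run is discarded as warm-up.

```python
import time
from statistics import mean
from typing import Callable, Dict, List


def profile_stages(stages: Dict[str, Callable], frame, runs: int = 10, warmup: int = 1) -> Dict[str, float]:
    """Run the stages in order on `frame`; return the mean time per stage in ms.

    Each stage receives the previous stage's output. The first `warmup`
    iterations are discarded because they include one-time initialization.
    """
    if runs < 1:
        raise ValueError("runs must be >= 1")
    samples: Dict[str, List[float]] = {name: [] for name in stages}
    for i in range(warmup + runs):
        data = frame
        for name, fn in stages.items():
            t0 = time.perf_counter()
            data = fn(data)
            elapsed_ms = (time.perf_counter() - t0) * 1000.0
            if i >= warmup:
                samples[name].append(elapsed_ms)
    return {name: mean(values) for name, values in samples.items()}


# Hypothetical stand-ins for a real pipeline.
def preprocess(pixels):
    return [p / 255.0 for p in pixels]  # scale 8-bit values to [0, 1]


def infer(x):
    time.sleep(0.005)  # pretend the model takes about 5 ms
    return sum(x)


def postprocess(score):
    return round(score, 3)


timings = profile_stages({"pre": preprocess, "infer": infer, "post": postprocess},
                         frame=list(range(256)), runs=5)
bottleneck = max(timings, key=timings.get)
print({k: round(v, 2) for k, v in timings.items()}, "slowest stage:", bottleneck)
```

If pre-processing or post-processing turns out to dominate, converting the model alone will not reach the target.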


Hi,

Which model do you use?

The ONNX flow (tf-onnx-trt) is verified for models trained with TensorFlow 2.
If v1.15 is used, please try to convert it into the uff format instead (tf-uff-trt).

Below is a sample for your reference:

Thanks

The model is said to be based on GoogleNet (although the sample Python execution script mentions Keras/ResNet50). I have its frozen graph as a .pb file, plus the .meta, .index, checkpoint and .data files.

I assume I cannot use the step provided in the sample directly:

sample_uff_mnist --datadir $TRT_DATADIR/mnist

Is that right?
Should I instead open Python, import tensorflow and uff, and then run the code below?

import tensorflow as tf
import uff

# Convert the frozen TensorFlow graph to UFF and write the result to test.uff
uff.from_tensorflow_frozen_model(frozen_file="frozen_file.pb", preprocessor=None, output_filename="test.uff")

Is that what you meant in your post? Or should I use the Python sample that converts MNIST to UFF?
Could you elaborate, please?

Do you suggest that I build TensorRT from source on the Jetson, from the NVIDIA/TensorRT GitHub repository?
Or does the apt package already include prebuilt samples?
Or are the specific steps below required? If so, does that mean the samples cannot be installed from apt? Or does the samples requirement not apply to what you meant in your post?

Select the platform and target OS (example: Jetson AGX Xavier, Linux Jetpack 5.0), and click Continue.
Under Download & Install Options change the download folder and select Download now, Install later. Agree to the license terms and click Continue.
Move the extracted files into the <TensorRT-OSS>/docker/jetpack_files folder.

However, running the Python code above does seem to produce the test.uff output file.

I can also list the layers, or run the conversion, with:

python3 /usr/lib/python3.8/dist-packages/uff/bin/convert_to_uff.py frozen_file.pb -l
python3 /usr/lib/python3.8/dist-packages/uff/bin/convert_to_uff.py frozen_file.pb -o output_file.uff

So how do I run inference from test.uff?
Do I need to run sample_uff_mnist, or some other specific code?
Thank you
AV

You may be able to convert the PB file to ONNX with tf2onnx.convert.
tf2onnx is compatible with TensorFlow 1.13-1.15 and 2.1-2.9.

--inputs and --outputs should be the input and output node names of your model, with the output index appended (e.g. image_tensor:0), as in the example below.
Example:

# install
pip3 install tf2onnx

# PB to ONNX
time python3 -m tf2onnx.convert --input frozen_inference_graph.pb --output model.onnx --opset 15 --inputs image_tensor:0 --outputs detection_classes:0,num_detections:0,detection_masks:0,detection_boxes:0,detection_scores:0
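Here is a small sketch of how the names passed to tf2onnx are formed: the example above appends the output index `:0` to each node name. `tensor_name` and `tf2onnx_cmd` are hypothetical helpers that only build the argument list; they do not call tf2onnx themselves.

```python
import shlex
from typing import List, Sequence


def tensor_name(node: str) -> str:
    """Node name -> tensor name: append output index 0 unless one is given."""
    return node if ":" in node else node + ":0"


def tf2onnx_cmd(pb: str, onnx_out: str, inputs: Sequence[str],
                outputs: Sequence[str], opset: int = 15) -> List[str]:
    """Argument list for `python3 -m tf2onnx.convert` on a frozen graph."""
    return [
        "python3", "-m", "tf2onnx.convert",
        "--input", pb,
        "--output", onnx_out,
        "--opset", str(opset),
        "--inputs", ",".join(tensor_name(n) for n in inputs),
        "--outputs", ",".join(tensor_name(n) for n in outputs),
    ]


cmd = tf2onnx_cmd("frozen_inference_graph.pb", "model.onnx",
                  inputs=["image_tensor"],
                  outputs=["detection_classes", "num_detections"])
print(" ".join(shlex.quote(arg) for arg in cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` if you want to script the conversion.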

Since I do not have your PB model, I converted this ONNX model to a TensorRT engine instead.

wget https://github.com/onnx/models/raw/main/vision/classification/inception_and_googlenet/googlenet/model/googlenet-12.onnx
time /usr/src/tensorrt/bin/trtexec --onnx=googlenet-12.onnx --saveEngine=googlenet-12_fp32.engine

Successfully converted with Jetson Nano 4GB JetPack 4.6.1.
log-convert-googlenet-onnx_to_engine.txt (18.3 KB)

@naisy Thank you for following up

time python3 -m tf2onnx.convert --input frozen_inference_graph.pb --output model.onnx --opset 15 --inputs image_tensor:0 --outputs detection_classes:0,num_detections:0,detection_masks:0,detection_boxes:0,detection_scores:0

This step does not get through; it fails with:

ValueError: NodeDef mentions attr 'explicit_paddings' not in Op<name=MaxPool; signature=input:T -> output:T; attr=T:type,default=DT_FLOAT,allowed=[DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16, DT_INT8, DT_UINT16, DT_QINT8]; attr=ksize:list(int),min=4; attr=strides:list(int),min=4; attr=padding:string,allowed=["SAME", "VALID"]; attr=data_format:string,default="NHWC",allowed=["NHWC", "NCHW", "NCHW_VECT_C"]>; NodeDef: {{node maxpool0}}. (Check whether your GraphDef-interpreting binary is up to date with your GraphDef-generating binary.).

This is probably because the model was created with TensorFlow 1.15.4, whereas I have 1.15.5 installed; judging from GitHub issues, that could be the cause of the error.

Converting to TFLite also fails, probably with the same versioning error.

I converted the PB file you sent me to ONNX.
First, I needed to check the input and output nodes.

Check the inputs/outputs of frozen_graph.pb

(on Jetson Nano 4GB JetPack 4.6.1 Tensorflow 1.15.5)

import tensorflow as tf

# Load the frozen graph (tf.compat.v1 and tf.io.gfile work in both TF 1.15 and TF 2.x)
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile("frozen_graph.pb", 'rb') as f:
    graph_def.ParseFromString(f.read())

print('========== nodes ==========')
for n in graph_def.node:
    print(n.name + ' => ' + n.op)

# Graph inputs are the Placeholder nodes
print('========== inputs ==========')
for n in graph_def.node:
    if n.op == 'Placeholder':
        print(n.name + ' => ' + n.op)

# Graph outputs are nodes whose output no other node consumes.
# Input references may look like "node:1" or "^node" (control dependency),
# so strip those decorations before comparing them with node names.
print('========== outputs ==========')
consumed = set()
for n in graph_def.node:
    for name in n.input:
        consumed.add(name.lstrip('^').split(':')[0])

for n in graph_def.node:
    if n.name not in consumed:
        print(n.name + ' => ' + n.op)

Result:

========== inputs ==========
x_in => Placeholder
========== outputs ==========
decoder/mul_1 => Mul
decoder/Softmax => Softmax

Convert from PB to ONNX

(on Jetson Nano 4GB JetPack 4.6.1 Tensorflow 1.15.5)

time python -m tf2onnx.convert --input frozen_graph.pb --output model.onnx --opset 12 --inputs x_in:0 --outputs decoder/mul_1:0,decoder/Softmax:0

However, an error occurred.

ValueError: NodeDef mentions attr 'explicit_paddings' not in Op<name=MaxPool; signature=input:T -> output:T; attr=T:type,default=DT_FLOAT,allowed=[DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16, DT_INT8, DT_UINT16, DT_QINT8]; attr=ksize:list(int),min=4; attr=strides:list(int),min=4; attr=padding:string,allowed=["SAME", "VALID"]; attr=data_format:string,default="NHWC",allowed=["NHWC", "NCHW", "NCHW_VECT_C"]>; NodeDef: {{node maxpool0}}. (Check whether your GraphDef-interpreting binary is up to date with your GraphDef-generating binary.).

As you pointed out, this appears to be a compatibility issue with the PB file.
I tried to find a TensorFlow version that could resolve this error on a Linux PC, and found that the PB file was generated with TensorFlow 2.x, not TensorFlow 1.x.

The conversion also works on Jetson Nano 4GB, JetPack 4.6.1, with TensorFlow 2.7.0.
(on Jetson Nano 4GB JetPack 4.6.1 Tensorflow 2.7.0)

time python -m tf2onnx.convert --input frozen_graph.pb --output model.onnx --opset 12 --inputs x_in:0 --outputs decoder/mul_1:0,decoder/Softmax:0

Result:

/usr/lib/python3.6/runpy.py:125: RuntimeWarning: 'tf2onnx.convert' found in sys.modules after import of package 'tf2onnx', but prior to execution of 'tf2onnx.convert'; this may result in unpredictable behaviour
  warn(RuntimeWarning(msg))
WARNING:tensorflow:From /virtualenv/python3/lib/python3.6/site-packages/tf2onnx/tf_loader.py:305: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
2022-10-06 13:49:41,472 - WARNING - From /virtualenv/python3/lib/python3.6/site-packages/tf2onnx/tf_loader.py:305: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
WARNING:tensorflow:From /virtualenv/python3/lib/python3.6/site-packages/tensorflow/python/framework/convert_to_constants.py:929: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
2022-10-06 13:49:41,473 - WARNING - From /virtualenv/python3/lib/python3.6/site-packages/tensorflow/python/framework/convert_to_constants.py:929: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
2022-10-06 13:49:42,876 - INFO - Using tensorflow=2.7.0, onnx=1.11.0, tf2onnx=1.12.1/b6d590
2022-10-06 13:49:42,876 - INFO - Using opset <onnx, 12>
2022-10-06 13:49:48,233 - INFO - Computed 0 values for constant folding
2022-10-06 13:49:52,516 - INFO - Optimizing ONNX model
2022-10-06 13:49:58,101 - INFO - After optimization: Cast -3 (3->0), Const -19 (146->127), Identity -2 (2->0), Reshape -1 (3->2), Transpose -144 (146->2)
2022-10-06 13:49:58,373 - INFO - 
2022-10-06 13:49:58,374 - INFO - Successfully converted TensorFlow model frozen_graph.pb to ONNX
2022-10-06 13:49:58,374 - INFO - Model inputs: ['x_in:0']
2022-10-06 13:49:58,377 - INFO - Model outputs: ['decoder/mul_1:0', 'decoder/Softmax:0']
2022-10-06 13:49:58,377 - INFO - ONNX model is saved at model.onnx

real	0m39.986s
user	0m26.892s
sys	0m3.444s

Using Jetson Nano 4GB JetPack 4.6.1 Tensorflow 2.7.0, I have successfully generated model.onnx from your frozen_graph.pb.
(Of course, this is also possible with Xavier.)

Convert ONNX to TensorRT

(on Jetson Nano 4GB JetPack 4.6.1)

time /usr/src/tensorrt/bin/trtexec --onnx=model.onnx --saveEngine=model.engine
&&&& RUNNING TensorRT.trtexec [TensorRT v8201] # /usr/src/tensorrt/bin/trtexec --onnx=model.onnx --saveEngine=model.engine
[10/06/2022-13:14:36] [I] === Model Options ===
[10/06/2022-13:14:36] [I] Format: ONNX
[10/06/2022-13:14:36] [I] Model: model.onnx
[10/06/2022-13:14:36] [I] Output:
[10/06/2022-13:14:36] [I] === Build Options ===
[10/06/2022-13:14:36] [I] Max batch: explicit batch
[10/06/2022-13:14:36] [I] Workspace: 16 MiB
[10/06/2022-13:14:36] [I] minTiming: 1
[10/06/2022-13:14:36] [I] avgTiming: 8
[10/06/2022-13:14:36] [I] Precision: FP32
[10/06/2022-13:14:36] [I] Calibration: 
[10/06/2022-13:14:36] [I] Refit: Disabled
[10/06/2022-13:14:36] [I] Sparsity: Disabled
[10/06/2022-13:14:36] [I] Safe mode: Disabled
[10/06/2022-13:14:36] [I] DirectIO mode: Disabled
[10/06/2022-13:14:36] [I] Restricted mode: Disabled
[10/06/2022-13:14:36] [I] Save engine: model.engine
[10/06/2022-13:14:36] [I] Load engine: 
[10/06/2022-13:14:36] [I] Profiling verbosity: 0
[10/06/2022-13:14:36] [I] Tactic sources: Using default tactic sources
[10/06/2022-13:14:36] [I] timingCacheMode: local
[10/06/2022-13:14:36] [I] timingCacheFile: 
[10/06/2022-13:14:36] [I] Input(s)s format: fp32:CHW
[10/06/2022-13:14:36] [I] Output(s)s format: fp32:CHW
[10/06/2022-13:14:36] [I] Input build shapes: model
[10/06/2022-13:14:36] [I] Input calibration shapes: model
[10/06/2022-13:14:36] [I] === System Options ===
[10/06/2022-13:14:36] [I] Device: 0
[10/06/2022-13:14:36] [I] DLACore: 
[10/06/2022-13:14:36] [I] Plugins:
[10/06/2022-13:14:36] [I] === Inference Options ===
[10/06/2022-13:14:36] [I] Batch: Explicit
[10/06/2022-13:14:36] [I] Input inference shapes: model
[10/06/2022-13:14:36] [I] Iterations: 10
[10/06/2022-13:14:36] [I] Duration: 3s (+ 200ms warm up)
[10/06/2022-13:14:36] [I] Sleep time: 0ms
[10/06/2022-13:14:36] [I] Idle time: 0ms
[10/06/2022-13:14:36] [I] Streams: 1
[10/06/2022-13:14:36] [I] ExposeDMA: Disabled
[10/06/2022-13:14:36] [I] Data transfers: Enabled
[10/06/2022-13:14:36] [I] Spin-wait: Disabled
[10/06/2022-13:14:36] [I] Multithreading: Disabled
[10/06/2022-13:14:36] [I] CUDA Graph: Disabled
[10/06/2022-13:14:36] [I] Separate profiling: Disabled
[10/06/2022-13:14:36] [I] Time Deserialize: Disabled
[10/06/2022-13:14:36] [I] Time Refit: Disabled
[10/06/2022-13:14:36] [I] Skip inference: Disabled
[10/06/2022-13:14:36] [I] Inputs:
[10/06/2022-13:14:36] [I] === Reporting Options ===
[10/06/2022-13:14:36] [I] Verbose: Disabled
[10/06/2022-13:14:36] [I] Averages: 10 inferences
[10/06/2022-13:14:36] [I] Percentile: 99
[10/06/2022-13:14:36] [I] Dump refittable layers:Disabled
[10/06/2022-13:14:36] [I] Dump output: Disabled
[10/06/2022-13:14:36] [I] Profile: Disabled
[10/06/2022-13:14:36] [I] Export timing to JSON file: 
[10/06/2022-13:14:36] [I] Export output to JSON file: 
[10/06/2022-13:14:36] [I] Export profile to JSON file: 
[10/06/2022-13:14:36] [I] 
[10/06/2022-13:14:36] [I] === Device Information ===
[10/06/2022-13:14:36] [I] Selected Device: NVIDIA Tegra X1
[10/06/2022-13:14:36] [I] Compute Capability: 5.3
[10/06/2022-13:14:36] [I] SMs: 1
[10/06/2022-13:14:36] [I] Compute Clock Rate: 0.9216 GHz
[10/06/2022-13:14:36] [I] Device Global Memory: 3964 MiB
[10/06/2022-13:14:36] [I] Shared Memory per SM: 64 KiB
[10/06/2022-13:14:36] [I] Memory Bus Width: 64 bits (ECC disabled)
[10/06/2022-13:14:36] [I] Memory Clock Rate: 0.01275 GHz
[10/06/2022-13:14:36] [I] 
[10/06/2022-13:14:36] [I] TensorRT version: 8.2.1
[10/06/2022-13:14:38] [I] [TRT] [MemUsageChange] Init CUDA: CPU +229, GPU +0, now: CPU 248, GPU 3340 (MiB)
[10/06/2022-13:14:38] [I] [TRT] [MemUsageSnapshot] Begin constructing builder kernel library: CPU 248 MiB, GPU 3340 MiB
[10/06/2022-13:14:38] [I] [TRT] [MemUsageSnapshot] End constructing builder kernel library: CPU 278 MiB, GPU 3369 MiB
[10/06/2022-13:14:38] [I] Start parsing network model
[10/06/2022-13:14:38] [I] [TRT] ----------------------------------------------------------------
[10/06/2022-13:14:38] [I] [TRT] Input filename:   model.onnx
[10/06/2022-13:14:38] [I] [TRT] ONNX IR version:  0.0.7
[10/06/2022-13:14:38] [I] [TRT] Opset version:    12
[10/06/2022-13:14:38] [I] [TRT] Producer name:    tf2onnx
[10/06/2022-13:14:38] [I] [TRT] Producer version: 1.12.1 b6d590
[10/06/2022-13:14:38] [I] [TRT] Domain:           
[10/06/2022-13:14:38] [I] [TRT] Model version:    0
[10/06/2022-13:14:38] [I] [TRT] Doc string:       
[10/06/2022-13:14:38] [I] [TRT] ----------------------------------------------------------------
[10/06/2022-13:14:38] [W] [TRT] onnx2trt_utils.cpp:366: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[10/06/2022-13:14:38] [W] [TRT] ShapedWeights.cpp:173: Weights decoder/Overfeat/ip/read:0 has been transposed with permutation of (1, 0)! If you plan on overwriting the weights with the Refitter API, the new weights must be pre-transposed.
[10/06/2022-13:14:38] [W] [TRT] ShapedWeights.cpp:173: Weights decoder/conf_ip0/read:0 has been transposed with permutation of (1, 0)! If you plan on overwriting the weights with the Refitter API, the new weights must be pre-transposed.
[10/06/2022-13:14:38] [W] [TRT] ShapedWeights.cpp:173: Weights decoder/box_ip0/read:0 has been transposed with permutation of (1, 0)! If you plan on overwriting the weights with the Refitter API, the new weights must be pre-transposed.
[10/06/2022-13:14:39] [I] Finish parsing network model
[10/06/2022-13:14:39] [I] [TRT] ---------- Layers Running on DLA ----------
[10/06/2022-13:14:39] [I] [TRT] ---------- Layers Running on GPU ----------
[10/06/2022-13:14:39] [I] [TRT] [GpuLayer] ExpandDims
[10/06/2022-13:14:39] [I] [TRT] [GpuLayer] PWN(sub/y:0 + (Unnamed Layer* 2) [Shuffle], sub)

...

[10/06/2022-13:25:28] [I] === Trace details ===
[10/06/2022-13:25:28] [I] Trace averages of 10 runs:
[10/06/2022-13:25:28] [I] Average on 10 runs - GPU latency: 698.522 ms - Host latency: 700.995 ms (end to end 701.006 ms, enqueue 3.5116 ms)
[10/06/2022-13:25:28] [I] 
[10/06/2022-13:25:28] [I] === Performance summary ===
[10/06/2022-13:25:28] [I] Throughput: 1.42652 qps
[10/06/2022-13:25:28] [I] Latency: min = 696.432 ms, max = 707.133 ms, mean = 700.995 ms, median = 700.71 ms, percentile(99%) = 707.133 ms
[10/06/2022-13:25:28] [I] End-to-End Host Latency: min = 696.439 ms, max = 707.146 ms, mean = 701.006 ms, median = 700.721 ms, percentile(99%) = 707.146 ms
[10/06/2022-13:25:28] [I] Enqueue Time: min = 3.04723 ms, max = 3.742 ms, mean = 3.5116 ms, median = 3.51733 ms, percentile(99%) = 3.742 ms
[10/06/2022-13:25:28] [I] H2D Latency: min = 2.38672 ms, max = 2.42432 ms, mean = 2.41469 ms, median = 2.41699 ms, percentile(99%) = 2.42432 ms
[10/06/2022-13:25:28] [I] GPU Compute Time: min = 693.989 ms, max = 704.663 ms, mean = 698.522 ms, median = 698.235 ms, percentile(99%) = 704.663 ms
[10/06/2022-13:25:28] [I] D2H Latency: min = 0.0561523 ms, max = 0.059082 ms, mean = 0.0576904 ms, median = 0.0576172 ms, percentile(99%) = 0.059082 ms
[10/06/2022-13:25:28] [I] Total Host Walltime: 7.01006 s
[10/06/2022-13:25:28] [I] Total GPU Compute Time: 6.98522 s
[10/06/2022-13:25:28] [I] Explanations of the performance metrics are printed in the verbose logs.
[10/06/2022-13:25:28] [I] 
&&&& PASSED TensorRT.trtexec [TensorRT v8201] # /usr/src/tensorrt/bin/trtexec --onnx=model.onnx --saveEngine=model.engine

real	10m52.601s
user	0m33.264s
sys	1m2.532s

Succeeded in generating model.engine.
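To compare engines without reading the whole log, the two key summary lines can be extracted with a regular expression. Here is a minimal sketch that assumes the TensorRT 8.2 trtexec output format shown above; other versions may word these lines differently. `parse_trtexec_summary` is a hypothetical helper. At batch size 1, "Throughput" in qps is roughly the inference-only frame rate, about 1.43 fps on the Nano here.

```python
import re
from typing import Dict

# Patterns for two lines of the trtexec "Performance summary" (TensorRT 8.2 format).
_THROUGHPUT = re.compile(r"Throughput:\s*([\d.]+)\s*qps")
_GPU_COMPUTE = re.compile(r"GPU Compute Time:.*?mean = ([\d.]+) ms")


def parse_trtexec_summary(log: str) -> Dict[str, float]:
    """Return throughput (queries/s) and mean GPU compute time (ms) if present."""
    result: Dict[str, float] = {}
    m = _THROUGHPUT.search(log)
    if m:
        result["throughput_qps"] = float(m.group(1))
    m = _GPU_COMPUTE.search(log)
    if m:
        result["gpu_compute_mean_ms"] = float(m.group(1))
    return result


# Two lines copied from the Jetson Nano FP32 run above.
sample_log = """\
[10/06/2022-13:25:28] [I] Throughput: 1.42652 qps
[10/06/2022-13:25:28] [I] GPU Compute Time: min = 693.989 ms, max = 704.663 ms, mean = 698.522 ms, median = 698.235 ms, percentile(99%) = 704.663 ms
"""
result = parse_trtexec_summary(sample_log)
print(result)
```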

TensorRT Inference

The rest of the processing (pre-processing and post-processing) has to be written to match the model's input and output layers.

trt_classificator_av.py (11.6 KB)

(on Jetson Nano 4GB JetPack 4.6.1)

python trt_classificator_av.py --model=model.engine --image=testimage.jpg 

Result:

load: 3.990297794342041 sec
1st frame: 0.8414063453674316 sec
infer: 0.8388795852661133 sec
[array([[0.01609352, 0.01584916, 0.0152326 , ..., 0.01504145, 0.01473314,
        0.01656689],
       [0.01590273, 0.01591793, 0.01520135, ..., 0.01501078, 0.01453196,
        0.01645717],
       [0.01578423, 0.01565683, 0.01522273, ..., 0.01504697, 0.01458219,
        0.0160788 ],
       ...,
       [0.01620023, 0.01539864, 0.01554307, ..., 0.01535674, 0.01525709,
        0.01682294],
       [0.01588741, 0.0157625 , 0.01551292, ..., 0.01530198, 0.01531301,
        0.01665569],
       [0.01568917, 0.01575915, 0.01554907, ..., 0.01541131, 0.01540376,
        0.01633411]], dtype=float32), array([[-1.0533719 , -5.413399  ,  2.8164186 , -0.19156489],
       [-0.10044873, -5.237488  ,  3.139062  ,  0.03426421],
       [ 0.444525  , -4.618743  ,  2.3823085 ,  0.33535182],
       ...,
       [-1.1889715 , -3.897892  ,  3.5675561 ,  0.4611286 ],
       [-0.91229236, -3.2173855 ,  2.3442209 ,  0.83963305],
       [-1.0568575 , -3.2151484 ,  1.8535607 ,  0.69146377]],
      dtype=float32)]

(on Jetson AGX Xavier 32GB JetPack 4.6.1)

python trt_classificator_av.py --model=model.engine --image=testimage.jpg 

Result:

load: 2.4099135398864746 sec
1st frame: 0.15322327613830566 sec
infer: 0.14696931838989258 sec
[array([[0.01609352, 0.01584916, 0.0152326 , ..., 0.01504145, 0.01473314,
        0.01656689],
       [0.01590273, 0.01591793, 0.01520135, ..., 0.01501078, 0.01453196,
        0.01645717],
       [0.01578423, 0.01565683, 0.01522273, ..., 0.01504697, 0.01458219,
        0.0160788 ],
       ...,
       [0.01620023, 0.01539864, 0.01554307, ..., 0.01535674, 0.01525709,
        0.01682294],
       [0.01588741, 0.0157625 , 0.01551292, ..., 0.01530198, 0.01531301,
        0.01665569],
       [0.01568917, 0.01575915, 0.01554907, ..., 0.01541131, 0.01540375,
        0.01633411]], dtype=float32), array([[-1.0533712 , -5.4133987 ,  2.8164186 , -0.19156438],
       [-0.10044891, -5.237487  ,  3.1390605 ,  0.03426504],
       [ 0.44452453, -4.618743  ,  2.3823073 ,  0.33535177],
       ...,
       [-1.1889708 , -3.897892  ,  3.5675566 ,  0.461129  ],
       [-0.9122926 , -3.2173848 ,  2.344221  ,  0.83963376],
       [-1.0568575 , -3.2151482 ,  1.8535609 ,  0.6914637 ]],
      dtype=float32)]

Ignore the 1st frame, as it includes initialization.
testimage.jpg is 4107x2743 pixels.
Inference takes 0.147 sec per image (0.14696931838989258 sec), which is about 6.8 fps.

The pre-processing in this script is probably different from yours. Please replace it with the pre-processing you use for your inference.


@naisy Thank you very much

The developers are asking whether batch processing could be used in this case, so that more images are processed per second.
Thanks

I exported ONNX from frozen_graph.pb, but the input node has no batch dimension.
Have you or the developers measured how much the FPS improves with larger batches in TensorFlow?

People from an ML class pointed out that batching could possibly be achieved either by using Triton Inference Server, where a batch parameter should be available somewhere, or by loading several images into a NumPy array and feeding that to the existing model.

Neither approach seems to require customizing the model.
However, I will ask the developers (the suppliers of the frozen graph), although they also seem unsure how to use batch parameters.

Regarding batching in Triton, are you referring to nvmultistreamtiler?
With nvmultistreamtiler, the batch size of the model input node is 1 and inference runs on multiple sources in parallel.

My Docker image is based on Triton, so if you have code and config files that work with Triton, you should be able to verify quickly that they work.
If you need to write the code yourself, note that your model is a custom model, so there may be no existing implementation to refer to.

What I meant by batch: in a typical model the input node is 4-dimensional, (BATCH_SIZE, H, W, C).
The input node in your frozen_graph.pb is (H, W, C), so the batch dimension is missing.

Normally, the input placeholder used for training in TensorFlow is defined to accept N images, because training is done in batches.
The same mechanism is used at inference time, so N images can be inferred at once.

In TensorRT, memory is allocated up front, so the batch size is usually fixed, typically at 1.
A larger fixed batch size may work, but today's models aim for low latency rather than high fps.
If the GPU load is low and batches of N images or parallel execution give a significant improvement in throughput, it may be worth doing.
Otherwise, I think it is better to keep the batch size at 1 and use FP16 or INT8.
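Whether batching pays off can be checked with simple arithmetic once timings for batch size 1 and batch size N are available, for example after re-exporting the model with a batch dimension. This is a minimal sketch. The 1.2x threshold for "significant" is an arbitrary assumption, and the batch-4 timings below are hypothetical, not measurements. Only the 0.147 s batch-1 figure comes from the FP32 Xavier result in this thread.

```python
def throughput_ips(batch_size: int, batch_latency_s: float) -> float:
    """Images per second when one batch of `batch_size` images takes batch_latency_s."""
    return batch_size / batch_latency_s


def batching_gain(lat_b1_s: float, lat_bn_s: float, n: int) -> float:
    """Throughput at batch size n relative to batch size 1."""
    return throughput_ips(n, lat_bn_s) / throughput_ips(1, lat_b1_s)


def worth_batching(lat_b1_s: float, lat_bn_s: float, n: int, min_gain: float = 1.2) -> bool:
    # min_gain is an arbitrary threshold for a "significant" improvement.
    return batching_gain(lat_b1_s, lat_bn_s, n) >= min_gain


lat_b1 = 0.147  # FP32 batch-1 time on AGX Xavier from this thread
for lat_b4 in (0.50, 0.30):  # hypothetical batch-4 timings, not measurements
    gain = batching_gain(lat_b1, lat_b4, 4)
    print(f"batch 4 in {lat_b4:.2f} s: {throughput_ips(4, lat_b4):.1f} img/s, "
          f"gain {gain:.2f}x, worth it: {worth_batching(lat_b1, lat_b4, 4)}")
```

Even when batching raises images per second, every image in a batch waits for the whole batch to finish. As a result, per-image latency goes up, which is the latency-versus-fps trade-off mentioned above.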

Correction: it is nvinfer and nvinferserver that have a batch-size setting, not nvmultistreamtiler. My mistake.

I also built and measured FP16 and INT8 engines.

FP16

/usr/src/tensorrt/bin/trtexec --onnx=model.onnx --saveEngine=model_fp16.engine --fp16
python trt_classificator_av.py --model=model_fp16.engine --image=testimage.jpg 
load: 2.4407081604003906 sec
1st frame: 0.09991908073425293 sec
infer: 0.0901947021484375 sec
[array([[0.01608747, 0.01584241, 0.01523666, ..., 0.0150394 , 0.01474018,
        0.01657081],
       [0.01589403, 0.01591903, 0.01520303, ..., 0.01500689, 0.01453697,
        0.01645847],
       [0.01578404, 0.01565488, 0.01522088, ..., 0.01504286, 0.0145905 ,
        0.01607842],
       ...,
       [0.0162069 , 0.01540244, 0.01553817, ..., 0.01534508, 0.01526381,
        0.01682989],
       [0.01589699, 0.01576799, 0.01550861, ..., 0.01529452, 0.01531881,
        0.01666016],
       [0.0156943 , 0.01575795, 0.01554175, ..., 0.01541713, 0.01540702,
        0.01633761]], dtype=float32), array([[-1.0683594 , -5.40625   ,  2.8320312 , -0.20666504],
       [-0.11517334, -5.2265625 ,  3.1503906 ,  0.02355957],
       [ 0.42407227, -4.6054688 ,  2.4003906 ,  0.32495117],
       ...,
       [-1.1621094 , -3.9375    ,  3.5996094 ,  0.4230957 ],
       [-0.87597656, -3.203125  ,  2.375     ,  0.79541016],
       [-1.0263672 , -3.1738281 ,  1.8730469 ,  0.6777344 ]],
      dtype=float32)]

INT8

Create an INT8 engine from frozen_graph.pb (a TensorFlow 2.x custom model).

I used the COCO val2017 images for calibration.
As far as possible, calibrate with images from the dataset you will run inference on.

# download val2017
wget http://images.cocodataset.org/zips/val2017.zip
unzip -qq val2017.zip
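The calibration log below processes the 5000 val2017 images in batches of 8 (batch 0 to batch 624). Here is a minimal sketch of building such a calibration set in a reproducible order, assuming plain JPEG/PNG files in a single directory. `list_calibration_images` and `make_batches` are hypothetical helpers, not the API of the sample's image_batcher.py.

```python
import os
from typing import List, Sequence

IMAGE_EXTS = (".jpg", ".jpeg", ".png")


def list_calibration_images(directory: str, max_images: int = 0) -> List[str]:
    """Image paths in `directory`, sorted for a reproducible order; 0 means no cap."""
    names = sorted(f for f in os.listdir(directory) if f.lower().endswith(IMAGE_EXTS))
    if max_images > 0:
        names = names[:max_images]
    return [os.path.join(directory, f) for f in names]


def make_batches(paths: Sequence[str], batch_size: int = 8) -> List[List[str]]:
    """Split paths into consecutive batches; the last one may be shorter."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    return [list(paths[i:i + batch_size]) for i in range(0, len(paths), batch_size)]


if __name__ == "__main__":
    # With 5000 images and batches of 8 this gives 625 batches (indices 0..624).
    demo = make_batches([f"img{i:05d}.jpg" for i in range(5000)])
    print(len(demo), len(demo[-1]))
    # For real data: make_batches(list_calibration_images("val2017"))
```

Sorting the file names keeps the calibration set identical between runs, which makes calibration caches easier to compare.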

Your model does not have a batch_size dimension, so you will need to rewrite the export code. The relevant files from the TensorRT sample are:

/home/jetson/github/TensorRT/samples/python/tensorflow_object_detection_api/image_batcher.py
/home/jetson/github/TensorRT/samples/python/tensorflow_object_detection_api/build_engine.py

image_batcher.py (8.3 KB)
build_engine.py (11.8 KB)

Convert ONNX to a TensorRT INT8 engine

time python ~/github/TensorRT/samples/python/tensorflow_object_detection_api/build_engine.py --onnx model.onnx --engine model_int8.engine --precision int8 --calib_input=val2017 --calib_cache model_int8.calib

Result:

time python ~/github/TensorRT/samples/python/tensorflow_object_detection_api/build_engine.py --onnx model.onnx --engine model_int8.engine --precision int8 --calib_input=val2017 --calib_cache model_int8.calib
[10/07/2022-17:41:57] [TRT] [I] [MemUsageChange] Init CUDA: CPU +356, GPU +0, now: CPU 388, GPU 12053 (MiB)
[10/07/2022-17:41:57] [TRT] [I] [MemUsageSnapshot] Begin constructing builder kernel library: CPU 388 MiB, GPU 12053 MiB
[10/07/2022-17:41:58] [TRT] [I] [MemUsageSnapshot] End constructing builder kernel library: CPU 493 MiB, GPU 12158 MiB
[10/07/2022-17:41:58] [TRT] [W] onnx2trt_utils.cpp:366: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[10/07/2022-17:41:58] [TRT] [W] ShapedWeights.cpp:173: Weights decoder/Overfeat/ip/read:0 has been transposed with permutation of (1, 0)! If you plan on overwriting the weights with the Refitter API, the new weights must be pre-transposed.
[10/07/2022-17:41:58] [TRT] [W] ShapedWeights.cpp:173: Weights decoder/conf_ip0/read:0 has been transposed with permutation of (1, 0)! If you plan on overwriting the weights with the Refitter API, the new weights must be pre-transposed.
[10/07/2022-17:41:58] [TRT] [W] ShapedWeights.cpp:173: Weights decoder/box_ip0/read:0 has been transposed with permutation of (1, 0)! If you plan on overwriting the weights with the Refitter API, the new weights must be pre-transposed.
INFO:EngineBuilder:Network Description

...

DataType.FLOAT
INFO:EngineBuilder:Building int8 Engine in /home/jetson/data/model_int8.engine
[10/07/2022-17:41:58] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +227, GPU +232, now: CPU 748, GPU 12421 (MiB)
[10/07/2022-17:41:59] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +307, GPU +305, now: CPU 1055, GPU 12726 (MiB)
[10/07/2022-17:41:59] [TRT] [I] Timing cache disabled. Turning it on will improve builder speed.
[10/07/2022-17:42:02] [TRT] [I] Detected 1 inputs and 2 output network tensors.
[10/07/2022-17:42:03] [TRT] [I] Total Host Persistent Memory: 16512
[10/07/2022-17:42:03] [TRT] [I] Total Device Persistent Memory: 0
[10/07/2022-17:42:03] [TRT] [I] Total Scratch Memory: 24509440
[10/07/2022-17:42:03] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 271 MiB
[10/07/2022-17:42:03] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 46.2998ms to assign 6 blocks to 161 nodes requiring 316293120 bytes.
[10/07/2022-17:42:03] [TRT] [I] Total Activation Memory: 316293120
[10/07/2022-17:42:03] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 1512, GPU 13493 (MiB)
[10/07/2022-17:42:03] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +0, now: CPU 1512, GPU 13493 (MiB)
[10/07/2022-17:42:03] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +1, GPU +0, now: CPU 1512, GPU 13493 (MiB)
[10/07/2022-17:42:03] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +0, GPU +0, now: CPU 1512, GPU 13493 (MiB)
[10/07/2022-17:42:03] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +301, now: CPU 0, GPU 333 (MiB)
[10/07/2022-17:42:03] [TRT] [I] Starting Calibration.
INFO:EngineBuilder:Calibrating image 8 / 5000
[10/07/2022-17:42:05] [TRT] [I]   Calibrated batch 0 in 1.44894 seconds.
INFO:EngineBuilder:Calibrating image 16 / 5000
[10/07/2022-17:42:06] [TRT] [I]   Calibrated batch 1 in 1.4619 seconds.
INFO:EngineBuilder:Calibrating image 24 / 5000
[10/07/2022-17:42:08] [TRT] [I]   Calibrated batch 2 in 1.4459 seconds.

...

[10/07/2022-18:00:07] [TRT] [I]   Calibrated batch 622 in 1.4218 seconds.
INFO:EngineBuilder:Calibrating image 4992 / 5000
[10/07/2022-18:00:09] [TRT] [I]   Calibrated batch 623 in 1.37199 seconds.
INFO:EngineBuilder:Calibrating image 5000 / 5000
[10/07/2022-18:00:10] [TRT] [I]   Calibrated batch 624 in 1.37575 seconds.
INFO:EngineBuilder:Finished calibration batches
[10/07/2022-18:00:48] [TRT] [I]   Post Processing Calibration data in 37.6704 seconds.
[10/07/2022-18:00:48] [TRT] [I] Calibration completed in 1130.22 seconds.
[10/07/2022-18:00:48] [TRT] [I] Writing Calibration Cache for calibrator: TRT-8201-EntropyCalibration2
INFO:EngineBuilder:Writing calibration cache data to: model_int8.calib
[10/07/2022-18:00:48] [TRT] [W] Missing scale and zero-point for tensor (Unnamed Layer* 1) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[10/07/2022-18:00:48] [TRT] [W] Missing scale and zero-point for tensor (Unnamed Layer* 149) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[10/07/2022-18:00:48] [TRT] [W] Missing scale and zero-point for tensor decoder/Reshape_3:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[10/07/2022-18:00:48] [TRT] [W] Missing scale and zero-point for tensor (Unnamed Layer* 161) [Shuffle]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[10/07/2022-18:00:48] [TRT] [W] Missing scale and zero-point for tensor (Unnamed Layer* 162) [Softmax]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[10/07/2022-18:00:48] [TRT] [W] Missing scale and zero-point for tensor (Unnamed Layer* 168) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[10/07/2022-18:00:48] [TRT] [I] ---------- Layers Running on DLA ----------
[10/07/2022-18:00:48] [TRT] [I] ---------- Layers Running on GPU ----------

...

[10/07/2022-18:00:48] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 1582, GPU 15006 (MiB)
[10/07/2022-18:00:48] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +1, GPU +0, now: CPU 1583, GPU 15006 (MiB)
[10/07/2022-18:00:48] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[10/07/2022-18:05:22] [TRT] [I] Some tactics do not have sufficient workspace memory to run. Increasing workspace size may increase performance, please check verbose output.
[10/07/2022-18:07:49] [TRT] [I] Detected 1 inputs and 2 output network tensors.
[10/07/2022-18:07:49] [TRT] [I] Total Host Persistent Memory: 86864
[10/07/2022-18:07:49] [TRT] [I] Total Device Persistent Memory: 14077952
[10/07/2022-18:07:49] [TRT] [I] Total Scratch Memory: 4055040
[10/07/2022-18:07:49] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 81 MiB, GPU 770 MiB
[10/07/2022-18:07:49] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 7.37309ms to assign 4 blocks to 71 nodes requiring 66654720 bytes.
[10/07/2022-18:07:49] [TRT] [I] Total Activation Memory: 66654720
[10/07/2022-18:07:49] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +7, now: CPU 1596, GPU 15758 (MiB)
[10/07/2022-18:07:49] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +1, GPU +8, now: CPU 1597, GPU 15766 (MiB)
[10/07/2022-18:07:49] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +7, GPU +16, now: CPU 7, GPU 16 (MiB)
INFO:EngineBuilder:Serializing engine to file: /home/jetson/data/model_int8.engine

real	25m54.304s
user	12m59.576s
sys	2m19.060s

INT8 inference:

python trt_classificator_av.py --model=model_int8.engine --image=testimage.jpg
load: 2.4899113178253174 sec
1st frame: 0.08790969848632812 sec
infer: 0.08123373985290527 sec
[array([[0.01563849, 0.01595112, 0.0156964 , ..., 0.01456072, 0.01473866,
        0.01655393],
       [0.01571514, 0.01588619, 0.01548513, ..., 0.01446212, 0.01450986,
        0.01636523],
       [0.01548512, 0.01525605, 0.01514633, ..., 0.01449381, 0.01466152,
        0.01619074],
       ...,
       [0.01610926, 0.01537016, 0.0157151 , ..., 0.01567235, 0.01550561,
        0.01676121],
       [0.01601627, 0.0155768 , 0.01564961, ..., 0.01563967, 0.01556563,
        0.01645875],
       [0.01588311, 0.0155883 , 0.01569124, ..., 0.01570016, 0.01566373,
        0.01626299]], dtype=float32), array([[ 0.12145996, -4.7773438 ,  2.5253906 , -1.3378906 ],
       [ 1.4853516 , -4.1601562 ,  2.5488281 , -0.71191406],
       [ 2.4570312 , -3.6328125 ,  2.0429688 ,  0.7548828 ],
       ...,
       [-1.6835938 , -4.3554688 ,  2.71875   , -0.69628906],
       [-1.3671875 , -3.34375   ,  2.125     , -0.05023193],
       [-1.4833984 , -3.046875  ,  1.4638672 , -0.29077148]],
      dtype=float32)]

Results on Jetson AGX Xavier 32GB (JetPack 4.6.1):
FP32:

load: 2.4099135398864746 sec
1st frame: 0.15322327613830566 sec
infer: 0.14696931838989258 sec

FP16:

load: 2.4407081604003906 sec
1st frame: 0.09991908073425293 sec
infer: 0.0901947021484375 sec

INT8:

load: 2.4899113178253174 sec
1st frame: 0.08790969848632812 sec
infer: 0.08123373985290527 sec
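The per-frame `infer` times above are easier to compare as frames per second. Below is a small Python sketch that converts them (values copied from the measurements above) to FPS and to a speedup relative to the FP32 engine.

```python
# Per-frame "infer" times (seconds) measured above on Jetson AGX Xavier 32GB.
timings = {
    "FP32": 0.14696931838989258,
    "FP16": 0.0901947021484375,
    "INT8": 0.08123373985290527,
}

def to_fps(seconds_per_frame: float) -> float:
    """Convert a per-frame latency in seconds to frames per second."""
    return 1.0 / seconds_per_frame

baseline = timings["FP32"]
for precision, t in timings.items():
    # Speedup is relative to the FP32 engine.
    print(f"{precision}: {to_fps(t):5.1f} FPS, {baseline / t:.2f}x vs FP32")
```

With these inputs it reports roughly 6.8, 11.1 and 12.3 FPS, so INT8 is only about 11% faster than FP16.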

As you can see, even with INT8 the speed has not increased much (0.081 s per frame vs 0.090 s for FP16).
To find out why, I measured preprocessing and inference separately.

trt_classificator_av.py (11.8 KB)
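Timing the stages separately only needs a wall-clock timer around each call. A minimal sketch, assuming hypothetical `preprocess` and `infer` callables that stand in for the corresponding steps of `trt_classificator_av.py` (the attached script itself is not reproduced here):

```python
import time
from typing import Any, Callable, Dict, Tuple

def time_stages(image: Any,
                preprocess: Callable[[Any], Any],
                infer: Callable[[Any], Any]) -> Tuple[Any, Dict[str, float]]:
    """Run preprocess -> infer once and return the output plus per-stage seconds.

    `preprocess` and `infer` are hypothetical placeholders. For a TensorRT
    engine, `infer` should include the host/device copies and the stream
    synchronization, otherwise asynchronous GPU work is not fully counted.
    """
    timings: Dict[str, float] = {}

    start = time.perf_counter()
    tensor = preprocess(image)
    timings["preprocess"] = time.perf_counter() - start

    start = time.perf_counter()
    output = infer(tensor)
    timings["infer"] = time.perf_counter() - start

    timings["total"] = timings["preprocess"] + timings["infer"]
    return output, timings
```

For example, `time_stages(frame, my_preprocess, my_infer)` returns the model output together with a dict such as `{"preprocess": ..., "infer": ..., "total": ...}`.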

FP32:

-- image --
preprocess: 0.037741661071777344 sec
infer: 0.10097241401672363 sec

FP16:

-- image --
preprocess: 0.041993141174316406 sec
infer: 0.04722285270690918 sec

INT8:

-- image --
preprocess: 0.04466676712036133 sec
infer: 0.03199887275695801 sec

As you can see, preprocessing is slow: with INT8 it takes longer (0.045 s) than inference itself (0.032 s).
Comment out the processing that is not needed in trt_classificator_av.py and run it again:

        #image = googlenet_preprocess(image)

INT8:

-- image --
preprocess: 0.013820648193359375 sec
infer: 0.03211164474487305 sec

Total: 0.045932292938232425 sec
This is 21.7 FPS.
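The effect of removing that step can be checked with a little arithmetic on the INT8 stage times above (values copied from the measurements). This sketch computes the per-frame total, the resulting FPS, and the share of the frame time spent in preprocessing, before and after commenting out `googlenet_preprocess()`:

```python
# INT8 stage times (seconds) measured above, before and after commenting out
# googlenet_preprocess() in trt_classificator_av.py.
before = {"preprocess": 0.04466676712036133, "infer": 0.03199887275695801}
after = {"preprocess": 0.013820648193359375, "infer": 0.03211164474487305}

def summarize(stages: dict) -> tuple:
    """Return (total seconds, FPS, fraction of time spent in preprocessing)."""
    total = sum(stages.values())
    return total, 1.0 / total, stages["preprocess"] / total

for label, stages in (("before", before), ("after", after)):
    total, fps, pre_share = summarize(stages)
    print(f"{label}: {total * 1000:.1f} ms/frame, {fps:.1f} FPS, "
          f"preprocess {pre_share:.0%}")
```

This gives roughly 13 FPS before and close to 22 FPS after; preprocessing drops from about 58% to about 30% of the frame time, while the inference time itself is unchanged.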

Addendum:

I have been trying to find out why the batch_size dimension is missing from the input node.
This may be the cause:

You can specify the input shape of your model in several different ways. For example by providing one of the following arguments to the first layer of your model:

batch_input_shape: A tuple where the first dimension is the batch size.
input_shape: A tuple that does not include the batch size, e.g., the batch size is assumed to be None or batch_size, if specified.
input_dim: A scalar indicating the dimension of the input.

Looking at the model, the input node x_in:0 has shape [H,W,C], but an ExpandDims op (Unsqueeze in ONNX) immediately expands it to [1,H,W,C].

Therefore, if I use ExpandDims:0 as the input node instead of x_in:0, the input will have a batch dimension, just like a normal model.
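In NumPy terms, ExpandDims (ONNX Unsqueeze) on axis 0 simply prepends a batch axis of size 1. A small sketch; the 224x224x3 shape is an arbitrary illustrative value, not the real model's input size:

```python
import numpy as np

# Illustrative [H, W, C] input; the real model's height/width are not assumed here.
x_in = np.zeros((224, 224, 3), dtype=np.float32)

# Equivalent of the graph's ExpandDims (ONNX Unsqueeze) on axis 0:
# [H, W, C] -> [1, H, W, C], i.e. a batch dimension of size 1 is prepended.
expanded = np.expand_dims(x_in, axis=0)

print(x_in.shape, "->", expanded.shape)
```

Taking ExpandDims:0 as the graph input means the caller supplies the already-batched [1,H,W,C] tensor directly, so the exposed input has a leading batch dimension.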
To use the model with DeepStream, also add the --inputs-as-nchw option when converting to ONNX.

Convert frozen_graph.pb (input: NHWC) to ONNX (input: NCHW):

time python -m tf2onnx.convert --input frozen_graph.pb --output model.onnx --opset 12 --inputs ExpandDims:0 --outputs decoder/mul_1:0,decoder/Softmax:0 --inputs-as-nchw ExpandDims:0

Now that I have an ONNX model with input node [1,C,H,W], I can convert it to TensorRT:

time /usr/src/tensorrt/bin/trtexec --onnx=model.onnx --saveEngine=model_fp32.engine
time /usr/src/tensorrt/bin/trtexec --onnx=model.onnx --saveEngine=model_fp16.engine --fp16
time python ~/github/TensorRT/samples/python/tensorflow_object_detection_api/build_engine.py --onnx model.onnx --engine model_int8.engine --precision int8 --calib_input=val2017 --calib_cache model_int8.calib

Now that the dimensions of the input node are correct, I will also modify the inference code.
trt_classificator_av.py (11.3 KB)
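On the inference side, the main change is that the input must now be supplied as [1,C,H,W] instead of [H,W,C]. A minimal NumPy sketch of that layout change, assuming the image has already been decoded and resized to the model's input size as an HWC array; any normalization the model needs is deliberately left out because it is not described here:

```python
import numpy as np

def to_nchw_batch(image_hwc: np.ndarray) -> np.ndarray:
    """Convert one [H, W, C] image to a contiguous float32 [1, C, H, W] tensor."""
    if image_hwc.ndim != 3:
        raise ValueError(f"expected [H, W, C], got shape {image_hwc.shape}")
    chw = np.transpose(image_hwc, (2, 0, 1))  # [H, W, C] -> [C, H, W]
    batched = chw[np.newaxis, ...]            # [C, H, W] -> [1, C, H, W]
    # A host-to-device copy into a TensorRT input buffer needs C-contiguous memory.
    return np.ascontiguousarray(batched, dtype=np.float32)
```

The transpose is done once on the host here; if preprocessing becomes the bottleneck again, this is one of the steps worth timing.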

### Infer
python trt_classificator_av.py --model=model_fp32.engine --image=testimage.jpg
python trt_classificator_av.py --model=model_fp16.engine --image=testimage.jpg
python trt_classificator_av.py --model=model_int8.engine --image=testimage.jpg

This engine can also be used for inference in DeepStream (with network-type=100 and output-tensor-meta=1).
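In DeepStream these two settings belong in the `[property]` group of the Gst-nvinfer configuration file. A minimal sketch that writes such a group with Python's `configparser`; the file path, engine name and the other keys shown are illustrative assumptions, not the configuration used in this thread. With output-tensor-meta=1 the raw output tensors are attached as metadata, so the application still has to parse them itself (for example in a pad probe).

```python
import configparser

def write_nvinfer_config(path: str, engine_file: str) -> configparser.ConfigParser:
    """Write a minimal Gst-nvinfer [property] group for a raw-tensor-output model.

    Illustrative only: the engine path and the extra keys are assumptions.
    """
    config = configparser.ConfigParser()
    config["property"] = {
        "gpu-id": "0",
        "model-engine-file": engine_file,
        "batch-size": "1",
        "network-mode": "1",         # 0=FP32, 1=INT8, 2=FP16
        "network-type": "100",       # "other": no built-in output parsing
        "output-tensor-meta": "1",   # attach raw output tensors as metadata
        "gie-unique-id": "1",
    }
    with open(path, "w") as f:
        # Write "key=value" without spaces, matching typical nvinfer config files.
        config.write(f, space_around_delimiters=False)
    return config
```

For example, `write_nvinfer_config("config_infer.txt", "model_int8.engine")` produces a file that can be referenced from the nvinfer element's `config-file-path`.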

