Calibration failed: INTERNAL: Failed to build TensorRT engine (INT8 precision mode) on Jetson Xavier NX (16GB)

Dear All,

I am trying to optimize a custom TensorFlow-Keras model. I am able to save the model and build a TF-TRT engine with precision mode FP32. I am also able to build a TF-TRT engine with precision mode FP16; however, its throughput is much lower than that of the FP32 engine. Lastly, with precision mode INT8 enabled, I am not able to generate the TF-TRT engine. I would appreciate your help in resolving this issue.

import os
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import cv2
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import tag_constants
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model
import tensorflow.keras.backend as K

from tensorflow.python.client import device_lib

def check_tensor_core_gpu_present():
    # Tensor Cores are available on GPUs with compute capability >= 7.0
    local_device_protos = device_lib.list_local_devices()
    for device in local_device_protos:
        if "compute capability" in str(device):
            compute_capability = float(device.physical_device_desc.split("compute capability: ")[-1])
            if compute_capability >= 7.0:
                return True
    return False

tensor_core_gpu = check_tensor_core_gpu_present()
print("Tensor Core GPU Present:", tensor_core_gpu)


model = tf.keras.models.load_model('custom_saved_model')

batch_size = 6
batched_input = np.zeros((batch_size, 96, 96, 1), dtype=np.float32)

for i in range(batch_size):
    # Cycle through the four sample images (img0.jpg to img3.jpg) to fill the batch
    img_path = '/home/nvidia/Documents/data/img%d.jpg' % (i % 4)
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError('Could not read image: %s' % img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = cv2.resize(img, (96, 96))
    # Scale to [0, 1], then standardize with the dataset mean and std
    img = img.astype("float") / 255.0
    img -= np.array([0.48112664], dtype="float32")
    img /= np.array([0.24075599], dtype="float32") + K.epsilon()
    img = img_to_array(img)  # (96, 96) -> (96, 96, 1)
    batched_input[i] = img
batched_input = tf.constant(batched_input)
print('batched_input shape: ', batched_input.shape)

def calibration_input_fn():
    # TF-TRT runs these batches through the graph to collect the
    # activation ranges used for INT8 calibration.
    yield (batched_input, )

print('Converting to TF-TRT INT8...')

converter = trt.TrtGraphConverterV2(input_saved_model_dir='custom_saved_model',
                                    precision_mode=trt.TrtPrecisionMode.INT8,
                                    max_workspace_size_bytes=8000000000)

converter.convert(calibration_input_fn=calibration_input_fn)
converter.save(output_saved_model_dir='custom_saved_model_TFTRT_INT8')
print('Done Converting to TF-TRT INT8')
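
Calibration does not have to use a single batch: `calibration_input_fn` may yield several. Below is a minimal numpy-only sketch of that, assuming the grayscale images have already been loaded into a `(N, 96, 96)` uint8 array. `standardize_gray` and `make_calibration_input_fn` are hypothetical helper names, and the mean/std constants are copied from the script above. Each yielded item is a one-element tuple, the same form as `yield (batched_input, )`.

```python
import numpy as np

# Dataset statistics copied from the preprocessing in the script;
# 1e-7 is the default value of tf.keras.backend.epsilon().
MEAN = np.float32(0.48112664)
DENOM = np.float32(0.24075599 + 1e-7)


def standardize_gray(images_u8):
    """(N, 96, 96) uint8 grayscale -> (N, 96, 96, 1) float32, standardized."""
    x = images_u8.astype(np.float32) / 255.0
    x = (x - MEAN) / DENOM
    return x[..., np.newaxis]  # add the channel axis the model expects


def make_calibration_input_fn(data, batch_size):
    """Return a generator function yielding (batch,) tuples for TF-TRT calibration."""
    def calibration_input_fn():
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            # Keep every calibration batch the same shape; a partial last batch is dropped.
            if len(batch) == batch_size:
                yield (batch,)
    return calibration_input_fn
```

With a hypothetical `images` array, the call would become `converter.convert(calibration_input_fn=make_calibration_input_fn(standardize_gray(images), batch_size))`. Dropping the ragged last batch is a design choice of this sketch, not a TF-TRT requirement.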

Below is the terminal output:

2023-03-20 11:33:51.922640: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:51.937119: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:51.937509: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:54.205399: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:54.205778: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:54.206055: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:54.206251: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /device:GPU:0 with 9717 MB memory: -> device: 0, name: Xavier, pci bus id: 0000:00:00.0, compute capability: 7.2
Tensor Core GPU Present: True
2023-03-20 11:33:54.208771: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:54.209031: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:54.209211: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:54.209469: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:54.209716: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:54.209817: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /device:GPU:0 with 9717 MB memory: -> device: 0, name: Xavier, pci bus id: 0000:00:00.0, compute capability: 7.2
img_path
img_path
img_path
img_path
img_path
img_path
img_path
2023-03-20 11:33:58.710686: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:58.711135: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:58.711380: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:58.713078: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:58.713361: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:58.713595: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:58.713869: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:58.714077: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:33:58.714208: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9717 MB memory: -> device: 0, name: Xavier, pci bus id: 0000:00:00.0, compute capability: 7.2
batched_input shape: (6, 96, 96, 1)
Converting to TF-TRT INT8...
2023-03-20 11:34:46.283157: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:34:46.283412: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2023-03-20 11:34:46.284078: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-03-20 11:34:46.285068: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:34:46.285423: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:34:46.285647: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:34:46.285989: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:34:46.286203: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:34:46.286343: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9717 MB memory: -> device: 0, name: Xavier, pci bus id: 0000:00:00.0, compute capability: 7.2
2023-03-20 11:34:46.431182: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:1176] Optimization results for grappler item: graph_to_optimize
function_optimizer: Graph size after: 296 nodes (224), 436 edges (364), time = 24.511ms.
function_optimizer: function_optimizer did nothing. time = 0.405ms.

2023-03-20 11:34:47.856416: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:34:47.856673: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2023-03-20 11:34:47.856956: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-03-20 11:34:47.857896: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:34:47.858197: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:34:47.858410: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:34:47.858901: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:34:47.859212: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1019] ARM64 does not support NUMA - returning NUMA node zero
2023-03-20 11:34:47.859346: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9717 MB memory: -> device: 0, name: Xavier, pci bus id: 0000:00:00.0, compute capability: 7.2
2023-03-20 11:34:48.261964: I tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc:408] [TF-TRT] not using explicit QDQ mode
2023-03-20 11:34:48.291207: W tensorflow/compiler/tf2tensorrt/segment/segment.cc:884]

################################################################################
TensorRT unsupported/non-converted OP Report:
- NoOp → 5x
- Identity → 1x
- Placeholder → 1x

- Total nonconverted OPs: 7
- Total nonconverted OP Types: 3

For more information see Accelerating Inference in TensorFlow with TensorRT User Guide - NVIDIA Docs.
################################################################################

2023-03-20 11:34:48.303032: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:806] Number of TensorRT candidate segments: 1
2023-03-20 11:34:48.332258: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:919] Replaced segment 0 consisting of 81 nodes by TRTEngineOp_0_0.
2023-03-20 11:34:48.467107: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:1176] Optimization results for grappler item: tf_graph
constant_folding: Graph size after: 152 nodes (-144), 282 edges (-154), time = 149.346ms.
layout: Graph size after: 156 nodes (4), 286 edges (4), time = 58.385ms.
constant_folding: Graph size after: 154 nodes (-2), 284 edges (-2), time = 19.569ms.
TensorRTOptimizer: Graph size after: 74 nodes (-80), 141 edges (-143), time = 80.546ms.
constant_folding: Graph size after: 74 nodes (0), 141 edges (0), time = 13.409ms.
Optimization results for grappler item: TRTEngineOp_0_0_native_segment
constant_folding: Graph size after: 143 nodes (0), 145 edges (0), time = 15.653ms.
layout: Graph size after: 143 nodes (0), 145 edges (0), time = 21.621ms.
constant_folding: Graph size after: 143 nodes (0), 145 edges (0), time = 15.971ms.
TensorRTOptimizer: Graph size after: 143 nodes (0), 145 edges (0), time = 2.124ms.
constant_folding: Graph size after: 143 nodes (0), 145 edges (0), time = 16.081ms.

2023-03-20 11:36:48.315051: I tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:444] TRTEngineOp not using explicit QDQ
2023-03-20 11:36:48.328994: I tensorflow/compiler/tf2tensorrt/common/utils.cc:94] Linked TensorRT version: 8.2.1
2023-03-20 11:36:48.337624: I tensorflow/compiler/tf2tensorrt/common/utils.cc:96] Loaded TensorRT version: 8.2.1
2023-03-20 11:36:51.421933: I tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:1509] [TF-TRT] Sparse compute capability is enabled.
2023-03-20 11:36:51.425016: W tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc:36] TF-TRT Warning: DefaultLogger It is suggested to disable layer timing cache while using AlgorithmSelector. Please refer to the developer guide in Developer Guide :: NVIDIA Deep Learning TensorRT Documentation.
2023-03-20 11:36:56.418784: E tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc:40] DefaultLogger 2: [utils.cpp::checkMemLimit::380] Error Code 2: Internal Error (Assertion upperBound != 0 failed. Unknown embedded device detected. Please update the table with the entry: {{1794, 6, 16}, 12653},)
2023-03-20 11:36:56.572008: E tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:1327] Calibration failed: INTERNAL: Failed to build TensorRT engine
Traceback (most recent call last):
  File "TRT_saved_model.py", line 171, in <module>
    converter.convert(calibration_input_fn=calibration_input_fn)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py", line 1259, in convert
    self._converted_func(*map(ops.convert_to_tensor, inp))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1707, in __call__
    return self._call_impl(args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/wrap_function.py", line 249, in _call_impl
    args, kwargs, cancellation_manager)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1725, in _call_impl
    return self._call_with_flat_signature(args, kwargs, cancellation_manager)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1774, in _call_with_flat_signature
    return self._call_flat(args, self.captured_inputs, cancellation_manager)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1960, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 603, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InternalError: Failed to feed calibration data
   [[node TRTEngineOp_0_0
 (defined at TRT_saved_model.py:171)
]] [Op:__inference_pruned_18613]

Errors may have originated from an input operation.
Input Source operations connected to node TRTEngineOp_0_0:
In[0] input_1:
In[1] StatefulPartitionedCall/NoOp:

Operation defined at: (most recent call last)

  File "TRT_saved_model.py", line 171, in <module>
    converter.convert(calibration_input_fn=calibration_input_fn)

Board: NVIDIA Jetson Xavier NX (16GB)
TensorFlow: 2.7.0
JetPack: 4.6.1
TensorRT: 8.2.1

Hi,

Is pure TensorRT an option for you?
On Jetson, it’s recommended to use pure TensorRT instead of TF-TRT for performance.

Thanks.

@AastaLLL Thanks for the quick reply. What do you mean by pure TensorRT? Do you mean using the TensorRT C++ API? Can you share some resources on converting a TensorFlow-Keras model with pure TensorRT?

Hi,

TF-TRT is the TensorRT integration built into the TensorFlow library.
So you need to load the whole TensorFlow library, even though the inference itself uses only the TensorRT API.

Pure TensorRT means running the model with the TensorRT library directly.
This can be done by converting the model to ONNX format and then feeding it to TensorRT.
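
For the Python route, building an engine from an ONNX file with the TensorRT 8.2 Python API could look roughly like the sketch below. It is not an official NVIDIA sample: `build_engine` is a hypothetical helper, the paths are placeholders, any dynamic input dimension (such as an unknown batch size left by the exporter) is pinned to an assumed `batch_size`, and INT8 is left out because it also needs a calibrator. `config.max_workspace_size` is the 8.2 attribute; newer TensorRT releases replace it with `set_memory_pool_limit`.

```python
def build_engine(onnx_path, engine_path, precision="fp16",
                 workspace_bytes=1 << 30, batch_size=1):
    """Parse an ONNX file and write a serialized TensorRT engine (TensorRT 8.2 API)."""
    if precision not in ("fp32", "fp16"):
        # INT8 would additionally require an IInt8Calibrator; out of scope here.
        raise ValueError("unsupported precision: %s" % precision)

    import tensorrt as trt  # imported lazily so the check above works without TensorRT

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    # The ONNX parser requires an explicit-batch network definition.
    flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flags)
    parser = trt.OnnxParser(network, logger)

    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
            raise RuntimeError("ONNX parse failed:\n" + "\n".join(errors))

    config = builder.create_builder_config()
    config.max_workspace_size = workspace_bytes  # TensorRT 8.2 attribute
    if precision == "fp16" and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    # Pin every dynamic (-1) input dimension to batch_size via one optimization profile.
    profile = builder.create_optimization_profile()
    has_dynamic = False
    for i in range(network.num_inputs):
        inp = network.get_input(i)
        dims = list(inp.shape)
        if -1 in dims:
            fixed = [batch_size if d == -1 else d for d in dims]
            profile.set_shape(inp.name, fixed, fixed, fixed)
            has_dynamic = True
    if has_dynamic:
        config.add_optimization_profile(profile)

    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("engine build failed")
    with open(engine_path, "wb") as f:
        f.write(serialized)
    return engine_path
```

For example, `build_engine("model.onnx", "model_fp16.engine")`; the resulting engine file can then be timed with `trtexec --loadEngine=model_fp16.engine`.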

TensorRT supports both C++ and Python. You can find some samples for TensorFlow models below:

Thanks.

@AastaLLL Thanks for the information. I installed TensorFlow following the steps on NVIDIA's official site: https://docs.nvidia.com/deeplearning/frameworks/install-tf-jetson-platform/index.html. Can you please suggest steps to remove the TensorRT that was installed with TensorFlow and install native TensorRT? By the way, the FP16 conversion works perfectly on the Jetson Nano (2GB) and, as expected, gives better throughput than FP32. I switched to the Jetson Xavier because the Jetson Nano does not have Tensor Cores and hence does not support INT8 conversion. So I feel that not supporting INT8, together with the lower throughput in FP16 mode, is a little unfair to the Jetson Xavier compared to the much cheaper Jetson Nano (2GB). What are your views on this?

Hi,

TensorRT is installed by default as part of JetPack.

Since INT8 is an integer format, some extra format conversion (quantization) is required.
In some cases, the overhead might slow down the inference.
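
As a rough illustration of that conversion, the sketch below applies symmetric per-tensor INT8 quantization, the scheme TensorRT's INT8 mode is built on: a scale is derived from the tensor's dynamic range (`amax`), values are rounded and clipped to [-127, 127], and results are multiplied back by the scale. The helper names are hypothetical, and `amax` is a fixed number here rather than the output of TensorRT's calibration; the sketch says nothing about TensorRT's actual kernels or where the overhead arises in a given network.

```python
import numpy as np


def quantize_int8(x, amax):
    """Symmetric per-tensor quantization: float values -> int8 codes and a scale."""
    scale = amax / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    """Map int8 codes back to approximate float values."""
    return q.astype(np.float32) * np.float32(scale)


x = np.array([-2.0, -0.5, 0.0, 0.3, 1.2, 3.0], dtype=np.float32)
amax = 2.0  # dynamic range; in TensorRT this would come from calibration
q, scale = quantize_int8(x, amax)
x_hat = dequantize(q, scale)
# 3.0 lies outside [-amax, amax], so it is clipped and comes back as 2.0
print(q, x_hat)
```

Values inside the range come back within half a quantization step (`scale / 2`); values outside it are clipped, which is why a representative calibration set matters.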

It’s recommended to try our TensorRT since it has been optimized for Jetson hardware.

The workflow is simple.
First, convert your model to ONNX format; this can usually be done with tf2onnx.
Then run inference with the TensorRT binary (trtexec):

$ /usr/src/tensorrt/bin/trtexec --onnx=[file] --fp16
$ /usr/src/tensorrt/bin/trtexec --onnx=[file] --int8
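
The two steps can be scripted. The sketch below assumes tf2onnx is installed (`pip install tf2onnx`) and that the SavedModel directory is `custom_saved_model` from the first post; `onnx_convert_cmd` and `trtexec_cmd` are hypothetical helpers, and opset 13 is an assumption rather than a requirement. `--saveEngine` writes the built engine to disk so later runs can use `--loadEngine` instead of rebuilding it.

```python
import sys

TRTEXEC = "/usr/src/tensorrt/bin/trtexec"  # path used in the reply above


def onnx_convert_cmd(saved_model_dir, onnx_path, opset=13):
    """Argument list for tf2onnx's SavedModel -> ONNX converter."""
    return [sys.executable, "-m", "tf2onnx.convert",
            "--saved-model", saved_model_dir,
            "--output", onnx_path,
            "--opset", str(opset)]


def trtexec_cmd(onnx_path, precision="fp32", engine_path=None):
    """Argument list for building and timing an engine with trtexec."""
    if precision not in ("fp32", "fp16", "int8"):
        raise ValueError("unsupported precision: %s" % precision)
    cmd = [TRTEXEC, "--onnx=" + onnx_path]
    if precision != "fp32":
        cmd.append("--" + precision)  # adds --fp16 or --int8
    if engine_path is not None:
        # Save the built engine so later runs can load it instead of rebuilding
        cmd.append("--saveEngine=" + engine_path)
    return cmd
```

On the device, these could be run in order with `subprocess.run(onnx_convert_cmd("custom_saved_model", "model.onnx"), check=True)` followed by `subprocess.run(trtexec_cmd("model.onnx", "fp16", "model_fp16.engine"), check=True)`. Without a calibration cache (`--calib=<file>`), a `--int8` trtexec run is, as far as I know, indicative of speed only, not of accuracy.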

Thanks.

After running the command “/usr/src/tensorrt/bin/trtexec --onnx=[file] --fp16”, I got the error below:

[03/22/2023-14:32:30] [I] [TRT] [GpuLayer] copied_squeeze_after_StatefulPartitionedCall/model/output_model1/BiasAdd
[03/22/2023-14:32:30] [I] [TRT] [GpuLayer] StatefulPartitionedCall/model/activation/Softmax
[03/22/2023-14:32:31] [I] [TRT] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +226, GPU +217, now: CPU 719, GPU 4985 (MiB)
[03/22/2023-14:32:32] [I] [TRT] [MemUsageChange] Init cuDNN: CPU +307, GPU +395, now: CPU 1026, GPU 5380 (MiB)
[03/22/2023-14:32:32] [I] [TRT] Local timing cache in use. Profiling results in this builder pass will not be stored.
[03/22/2023-14:32:32] [E] Error[2]: [utils.cpp::checkMemLimit::380] Error Code 2: Internal Error (Assertion upperBound != 0 failed. Unknown embedded device detected. Please update the table with the entry: {{1794, 6, 16}, 12653},)
[03/22/2023-14:32:32] [E] Error[2]: [builder.cpp::buildSerializedNetwork::609] Error Code 2: Internal Error (Assertion enginePtr != nullptr failed. )
[03/22/2023-14:32:32] [E] Engine could not be created from network
[03/22/2023-14:32:32] [E] Building engine failed
[03/22/2023-14:32:32] [E] Failed to create engine from model.
[03/22/2023-14:32:32] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec [TensorRT v8201] # /usr/src/tensorrt/bin/trtexec --onnx=/home/nvidia/Documents/fernet_raf_gray.onnx --fp16

Environment:
Board: NVIDIA Jetson Xavier NX (16GB)
TensorFlow: 2.7.0
JetPack: 4.6.1
TensorRT: 8.2.1
CUDA: 10.2

Please help me resolve this issue.

Hi,

Please upgrade to JetPack 4.6.3 and try it again.
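
Before retrying, it can help to confirm which L4T release the board is actually running: JetPack 4.6.1 ships L4T R32.7.1 and JetPack 4.6.3 ships L4T R32.7.3. The first line of `/etc/nv_tegra_release` on the device reports this in the form `# R32 (release), REVISION: 7.x, ...`. The sketch below parses that line; `l4t_version` and `read_l4t_version` are hypothetical helpers, and the format is assumed from that header line rather than from a formal specification.

```python
import re

# Matches e.g. "# R32 (release), REVISION: 7.3, ..." and captures 32, 7, 3
_L4T_RE = re.compile(r"#\s*R(\d+)\s*\(release\),\s*REVISION:\s*(\d+)\.(\d+)")


def l4t_version(text):
    """Return (major, minor, patch) parsed from nv_tegra_release text, or None."""
    m = _L4T_RE.search(text)
    if m is None:
        return None
    return tuple(int(g) for g in m.groups())


def read_l4t_version(path="/etc/nv_tegra_release"):
    """Read the first line of the release file on a Jetson and parse it."""
    with open(path) as f:
        return l4t_version(f.readline())
```

On the device, `read_l4t_version() >= (32, 7, 3)` would then indicate that the JetPack 4.6.3 release (or newer) is installed.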

Thanks.

@AastaLLL Thank you. Flashing JetPack 4.6.3 solved the issue, and the TF-TRT conversion now works in all precision modes (FP32, FP16, and INT8). However, I still have issues flashing JetPack to external storage devices. For now, I am closing this issue. If needed, I will open a separate thread about flashing JetPack to an external SD card on the Jetson Xavier NX module.

Thank you for your assistance and your time.
