My TF-TRT converted saved_model gets warnings and slow inference

Description

I fine-tuned the TensorFlow Object Detection API model CenterNet Resnet50 V1 FPN Keypoints 512x512 (models/tf2_detection_zoo.md at master · tensorflow/models · GitHub) and converted the saved_model to a TF-TRT model with FP16 precision.
However, the converted model emits warnings during inference and is slow (a few seconds per inference).
In contrast, the native saved_model takes 0.03 sec per inference.

I see two types of warnings.

2023-01-16 02:42:33.165965: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:847] TF-TRT Warning: Running native segment forPartitionedCall/TRTEngineOp_000_005 due to failure in verifying input shapes: Incorrect batch dimension, for PartitionedCall/TRTEngineOp_000_005: [[0,1], [0,1]]
2023-01-16 02:42:34.218372: W tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:5907] TF-TRT Warning: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,1] is an empty tensor, which is not supported by TRT
2023-01-16 02:42:34.218490: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:1104] TF-TRT Warning: Engine creation for PartitionedCall/TRTEngineOp_000_002 failed. The native segment will be used instead. Reason: UNIMPLEMENTED: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,1] is an empty tensor, which is not supported by TRT
2023-01-10 10:37:03.250971: W tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc:83] TF-TRT Warning: DefaultLogger TensorRT encountered issues when converting weights between types and that could affect accuracy.
2023-01-10 10:37:03.251011: W tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc:83] TF-TRT Warning: DefaultLogger Check verbose logs for the list of affected weights.
2023-01-10 10:37:03.251019: W tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc:83] TF-TRT Warning: DefaultLogger - 3 weights are affected by this issue: Detected finite FP32 values which would overflow in FP16 and converted them to the closest finite FP16 value.
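
The last warning above (finite FP32 values that would overflow in FP16) can be reproduced in isolation. The sketch below uses NumPy rather than TensorRT, which is an assumption made purely for illustration since TensorRT's internal weight conversion is not shown in this thread, and the weight values are made up. It shows how an out-of-range FP32 value overflows on a naive cast and what clamping to the closest finite FP16 value looks like.

```python
import numpy as np

# Hypothetical FP32 weights; two of them exceed the FP16 range.
weights_fp32 = np.array([1.0, 70000.0, -1.0e6], dtype=np.float32)

fp16_max = float(np.finfo(np.float16).max)  # largest finite FP16 value (65504.0)

# A naive cast overflows to +/-inf for out-of-range values.
with np.errstate(over="ignore"):
    naive_fp16 = weights_fp32.astype(np.float16)

# Clamping first maps each out-of-range weight to the closest finite FP16
# value, which is the behavior the warning text describes.
affected = np.abs(weights_fp32) > fp16_max
clamped_fp16 = np.clip(weights_fp32, -fp16_max, fp16_max).astype(np.float16)

print("affected weights:", int(affected.sum()))
print("naive cast:", naive_fp16)
print("clamped:", clamped_fp16)
```

Whether such clamping matters for accuracy depends on which weights are affected, which the verbose logs mentioned in the warning would list.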

Environment

TensorRT Version: 8.4.1
GPU Type: Tesla V100-PCIE-32GB
Nvidia Driver Version: 515.86.01
CUDA Version: 11.7
CUDNN Version:
Operating System + Version: Ubuntu 20.04
Python Version: 3.8.12
TensorFlow Version: tf-nightly 2.12.0.dev20230110

Steps To Reproduce

  1. Fine-tune CenterNet Resnet50 V1 FPN Keypoints 512x512 and convert the pb file to a saved_model.
  2. Convert the saved_model to a TF-TRT model with the following script.
from tensorflow.python.compiler.tensorrt import trt_convert as trt
import numpy as np

num_runs = 1

SAVED_MODEL_DIR=f'./saved_model'

precision_mode = 'FP16'
conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS
conversion_params = conversion_params._replace(precision_mode=precision_mode)

# Instantiate the TF-TRT converter
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=SAVED_MODEL_DIR,
    conversion_params=conversion_params
)

# Convert the model into TRT compatible segments
trt_func = converter.convert()
converter.summary()

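# Build the engines ahead of time with an input of the serving shape and dtype
def my_input_fn():
    for _ in range(num_runs):
        # Random uint8 pixels in [0, 255]; casting np.random.normal() output
        # to uint8 does not produce a realistic image
        inp1 = np.random.randint(0, 256, size=(1, 512, 512, 3), dtype=np.uint8)
        yield (inp1,)
converter.build(input_fn=my_input_fn)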

OUTPUT_SAVED_MODEL_DIR=f"./tftrt_saved_model_{precision_mode.lower()}"
converter.save(output_saved_model_dir=OUTPUT_SAVED_MODEL_DIR)

print(f'Done Converting to TF-TRT {precision_mode}')

This is the conversion log.

$ python saved_model_to_tftrt.py 
2023-01-22 09:14:48.206978: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-22 09:14:48.339506: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-01-22 09:14:48.384884: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay
2023-01-22 09:14:50.037924: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-22 09:14:50.949046: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1628] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30941 MB memory:  -> device: 0, name: Tesla V100-PCIE-32GB, pci bus id: 0000:af:00.0, compute capability: 7.0
2023-01-22 09:15:04.607141: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 1
2023-01-22 09:15:04.607465: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-01-22 09:15:04.608681: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1628] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30941 MB memory:  -> device: 0, name: Tesla V100-PCIE-32GB, pci bus id: 0000:af:00.0, compute capability: 7.0
2023-01-22 09:15:11.025318: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 1
2023-01-22 09:15:11.025509: I tensorflow/core/grappler/clusters/single_machine.cc:358] Starting new session
2023-01-22 09:15:11.026860: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1628] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30941 MB memory:  -> device: 0, name: Tesla V100-PCIE-32GB, pci bus id: 0000:af:00.0, compute capability: 7.0
2023-01-22 09:15:12.593974: W tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc:186] Calibration with FP32 or FP16 is not implemented. Falling back to use_calibration = False.Note that the default value of use_calibration is True.
2023-01-22 09:15:12.979638: W tensorflow/compiler/tf2tensorrt/segment/segment.cc:962] 

################################################################################
TensorRT unsupported/non-converted OP Report:
	- Reshape -> 109x
	- Cast -> 92x
	- Pack -> 82x
	- StridedSlice -> 60x
	- Shape -> 39x
	- Tile -> 30x
	- Mul -> 29x
	- GatherNd -> 28x
	- Less -> 27x
	- ExpandDims -> 26x
	- AddV2 -> 22x
	- Sub -> 22x
	- Greater -> 21x
	- Unpack -> 19x
	- Select -> 15x
	- GatherV2 -> 15x
	- Fill -> 15x
	- Placeholder -> 13x
	- Identity -> 12x
	- Transpose -> 10x
	- GreaterEqual -> 10x
	- TensorScatterAdd -> 10x
	- EnsureShape -> 10x
	- Switch -> 8x
	- Merge -> 7x
	- LogicalAnd -> 6x
	- ConcatV2 -> 6x
	- Enter -> 6x
	- RealDiv -> 6x
	- NextIteration -> 6x
	- Where -> 5x
	- ArgMin -> 5x
	- Maximum -> 5x
	- Equal -> 5x
	- Round -> 4x
	- Pad -> 3x
	- Squeeze -> 2x
	- ResizeBilinear -> 2x
	- NoOp -> 2x
	- TensorListReserve -> 2x
	- TensorListSetItem -> 2x
	- TensorListStack -> 2x
	- Minimum -> 2x
	- Exit -> 2x
	- TensorListFromTensor -> 1x
	- TensorListGetItem -> 1x
	- LoopCond -> 1x
--------------------------------------------------------------------------------
	- Total nonconverted OPs: 807
	- Total nonconverted OP Types: 47
For more information see https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html#supported-ops.
################################################################################

2023-01-22 09:15:13.451224: W tensorflow/compiler/tf2tensorrt/segment/segment.cc:1239] A total of 25 segments with at least minimum_segment_size=3 nodes have been found. TF-TRT will only convert the 20 largest segments. You can change this behavior by modifying the environment variable TF_TRT_MAX_ALLOWED_ENGINES=20
2023-01-22 09:15:13.459988: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:799] Number of TensorRT candidate segments: 20
2023-01-22 09:15:13.720193: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 0 consisting of 30 nodes by TRTEngineOp_000_000.
2023-01-22 09:15:13.720347: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 1 consisting of 5 nodes by TRTEngineOp_000_001.
2023-01-22 09:15:13.720412: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 2 consisting of 5 nodes by TRTEngineOp_000_002.
2023-01-22 09:15:13.720474: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 3 consisting of 5 nodes by TRTEngineOp_000_003.
2023-01-22 09:15:13.720535: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 4 consisting of 5 nodes by TRTEngineOp_000_004.
2023-01-22 09:15:13.720613: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 5 consisting of 5 nodes by TRTEngineOp_000_005.
2023-01-22 09:15:13.720671: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 6 consisting of 4 nodes by TRTEngineOp_000_006.
2023-01-22 09:15:13.720728: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 7 consisting of 4 nodes by TRTEngineOp_000_007.
2023-01-22 09:15:13.720782: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 8 consisting of 4 nodes by TRTEngineOp_000_008.
2023-01-22 09:15:13.720837: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 9 consisting of 4 nodes by TRTEngineOp_000_009.
2023-01-22 09:15:13.721089: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 10 consisting of 86 nodes by TRTEngineOp_000_010.
2023-01-22 09:15:13.721345: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 11 consisting of 39 nodes by TRTEngineOp_000_011.
2023-01-22 09:15:13.721534: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 12 consisting of 52 nodes by TRTEngineOp_000_012.
2023-01-22 09:15:13.721789: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 13 consisting of 794 nodes by TRTEngineOp_000_013.
2023-01-22 09:15:13.723540: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 14 consisting of 5 nodes by TRTEngineOp_000_014.
2023-01-22 09:15:13.723625: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 15 consisting of 7 nodes by TRTEngineOp_000_015.
2023-01-22 09:15:13.723706: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 16 consisting of 16 nodes by TRTEngineOp_000_016.
2023-01-22 09:15:13.723841: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 17 consisting of 31 nodes by TRTEngineOp_000_017.
2023-01-22 09:15:13.723937: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 18 consisting of 6 nodes by TRTEngineOp_000_018.
2023-01-22 09:15:13.724088: I tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc:913] Replaced segment 19 consisting of 21 nodes by TRTEngineOp_000_019.
2023-01-22 09:15:16.326664: I tensorflow/core/common_runtime/executor.cc:1195] [/device:CPU:0] Executor start aborting: INVALID_ARGUMENT: You must feed a value for placeholder tensor 'unused_control_flow_input_8' with dtype int32 and shape [1]
	 [[{{node unused_control_flow_input_8}}]]
2023-01-22 09:15:16.338144: I tensorflow/core/common_runtime/executor.cc:1195] [/device:CPU:0] Executor start aborting: INVALID_ARGUMENT: You must feed a value for placeholder tensor 'unused_control_flow_input_14' with dtype int32 and shape [1]
	 [[{{node unused_control_flow_input_14}}]]
2023-01-22 09:15:16.340213: I tensorflow/core/common_runtime/executor.cc:1195] [/device:CPU:0] Executor start aborting: INVALID_ARGUMENT: You must feed a value for placeholder tensor 'unused_control_flow_input_16' with dtype int32 and shape [1]
	 [[{{node unused_control_flow_input_16}}]]
TRTEngineOP Name                 Device        # Nodes # Inputs      # Outputs     Input DTypes       Output Dtypes      Input Shapes       Output Shapes     
================================================================================================================================================================

----------------------------------------

TRTEngineOp_000_000              device:GPU:0  34      6             3             ['float16', 'f ... ['float32', 'f ... [[1, 1, 4], [1 ... [[1, 10, 4], [ ...

	- AddV2: 5x
	- Cast: 4x
	- Const: 5x
	- Maximum: 6x
	- Minimum: 5x
	- Mul: 3x
	- Pack: 1x
	- RealDiv: 1x
	- Sub: 2x
	- Unpack: 2x

----------------------------------------

TRTEngineOp_000_001              device:GPU:0  5       4             1             ['float16', 'f ... ['float32']        [[1, -1, 1], [ ... [[1, -1, -1, 2]]  

	- AddV2: 2x
	- Cast: 2x
	- Pack: 1x

----------------------------------------

TRTEngineOp_000_002              device:GPU:0  5       4             1             ['float16', 'f ... ['float32']        [[1, -1, 1], [ ... [[1, -1, -1, 2]]  

	- AddV2: 2x
	- Cast: 2x
	- Pack: 1x

----------------------------------------

TRTEngineOp_000_003              device:GPU:0  5       4             1             ['float16', 'f ... ['float32']        [[1, -1, 1], [ ... [[1, -1, -1, 2]]  

	- AddV2: 2x
	- Cast: 2x
	- Pack: 1x

----------------------------------------

TRTEngineOp_000_004              device:GPU:0  5       4             1             ['float16', 'f ... ['float32']        [[1, -1, 1], [ ... [[1, -1, -1, 2]]  

	- AddV2: 2x
	- Cast: 2x
	- Pack: 1x

----------------------------------------

TRTEngineOp_000_005              device:GPU:0  5       4             1             ['float16', 'f ... ['float32']        [[1, -1, 1], [ ... [[1, -1, -1, 2]]  

	- AddV2: 2x
	- Cast: 2x
	- Pack: 1x

----------------------------------------

TRTEngineOp_000_006              device:GPU:0  5       2             1             ['int32', 'int32'] ['int32']          [[-1, 1], [-1, 1]] [[-1, 2, 2]]      

	- Const: 2x
	- Mul: 2x
	- Pack: 1x

----------------------------------------

TRTEngineOp_000_007              device:GPU:0  5       2             1             ['int32', 'int32'] ['int32']          [[-1, 1], [-1, 1]] [[-1, 3, 2]]      

	- Const: 2x
	- Mul: 2x
	- Pack: 1x

----------------------------------------

TRTEngineOp_000_008              device:GPU:0  5       2             1             ['int32', 'int32'] ['int32']          [[-1, 1], [-1, 1]] [[-1, 3, 2]]      

	- Const: 2x
	- Mul: 2x
	- Pack: 1x

----------------------------------------

TRTEngineOp_000_009              device:GPU:0  5       2             1             ['int32', 'int32'] ['int32']          [[-1, 1], [-1, 1]] [[-1, 3, 2]]      

	- Const: 2x
	- Mul: 2x
	- Pack: 1x

----------------------------------------

TRTEngineOp_000_010              device:GPU:0  91      12            31            ['float32', 'i ... ['float32', 'f ... [[1, 3, 100],  ... [[1, 100, 3],  ...

	- AddV2: 5x
	- Cast: 1x
	- Const: 14x
	- FloorDiv: 12x
	- Mul: 18x
	- Pack: 1x
	- Reshape: 16x
	- StridedSlice: 1x
	- Sub: 12x
	- TopKV2: 1x
	- Transpose: 10x

----------------------------------------

TRTEngineOp_000_011              device:GPU:0  29      10            10            ['float16', 'f ... ['float32', 'i ... [[1, 128, 128, ... [[1, 3, 100],  ...

	- Cast: 5x
	- Const: 4x
	- Mul: 5x
	- Reshape: 5x
	- TopKV2: 5x
	- Transpose: 5x

----------------------------------------

TRTEngineOp_000_012              device:GPU:0  54      15            10            ['float16', 'f ... ['float32', 'f ... [[1, 300], [1, ... [[1, 3, 100, 2 ...

	- AddV2: 10x
	- Cast: 10x
	- Const: 4x
	- ExpandDims: 5x
	- Pack: 5x
	- Reshape: 5x
	- Transpose: 10x
	- Unpack: 5x

----------------------------------------

TRTEngineOp_000_013              device:GPU:0  790     3             24            ['float32', 'f ... ['float32', 'f ... [[1, 128, 1],  ... [[1, 128, 128, ...

	- Abs: 6x
	- AddV2: 19x
	- BiasAdd: 88x
	- Const: 413x
	- Conv2D: 96x
	- ExpandDims: 1x
	- FusedBatchNormV3: 56x
	- MaxPool: 7x
	- Mul: 2x
	- Pad: 2x
	- Relu: 70x
	- ResizeNearestNeighbor: 3x
	- Sigmoid: 6x
	- StridedSlice: 15x
	- Sub: 6x

----------------------------------------

TRTEngineOp_000_014              device:GPU:0  6       2             2             ['float32', 'f ... ['float32', 'f ... [[1, -1, 4], [ ... [[1, -1, 2, 4] ...

	- Const: 2x
	- ExpandDims: 2x
	- Tile: 2x

----------------------------------------

TRTEngineOp_000_015              device:GPU:0  8       3             3             ['float32', 'f ... ['float32', 'f ... [[1, -1, 4], [ ... [[1, -1, 3, 4] ...

	- Const: 2x
	- ExpandDims: 3x
	- Tile: 3x

----------------------------------------

TRTEngineOp_000_016              device:GPU:0  17      5             5             ['int32', 'int ... ['int32', 'int ... [[1, 100, 3],  ... [[1, 100, 3],  ...

	- Const: 2x
	- ExpandDims: 5x
	- Sum: 5x
	- Tile: 5x

----------------------------------------

TRTEngineOp_000_017              device:GPU:0  33      10            5             ['float32', 'f ... ['float32', 'f ... [[1, -1, 3, 2] ... [[1, -1, 100,  ...

	- Const: 3x
	- ExpandDims: 5x
	- Mul: 5x
	- Sqrt: 5x
	- Sub: 5x
	- Sum: 5x
	- Tile: 5x

----------------------------------------

TRTEngineOp_000_018              device:GPU:0  6       5             5             ['float32', 'f ... ['float32', 'f ... [[1, -1, 100,  ... [[1, -1, 100,  ...

	- Const: 1x
	- Mul: 5x

----------------------------------------

TRTEngineOp_000_019              device:GPU:0  21      20            5             ['float32', 'f ... ['float32', 'f ... [[1, -1, 3], [ ... [[1, -1, 3], [ ...

	- Const: 1x
	- Maximum: 5x
	- Mul: 5x
	- Sub: 10x

================================================================================================================================================================
[*] Total number of TensorRT engines: 20
[*] % of OPs Converted: 52.99% [1134/2140]

2023-01-22 09:15:20.108012: I tensorflow/compiler/tf2tensorrt/common/utils.cc:104] Linked TensorRT version: 8.4.3
2023-01-22 09:15:20.108405: I tensorflow/compiler/tf2tensorrt/common/utils.cc:106] Loaded TensorRT version: 8.4.3
2023-01-22 09:15:23.640329: I tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:1330] [TF-TRT] Sparse compute capability: enabled.
2023-01-22 09:20:03.394695: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:847] TF-TRT Warning: Running native segment forTRTEngineOp_000_007 due to failure in verifying input shapes: Incorrect batch dimension, for TRTEngineOp_000_007: [[0,1], [0,1]]
2023-01-22 09:20:03.394812: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:847] TF-TRT Warning: Running native segment forTRTEngineOp_000_008 due to failure in verifying input shapes: Incorrect batch dimension, for TRTEngineOp_000_008: [[0,1], [0,1]]
2023-01-22 09:20:03.394919: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:847] TF-TRT Warning: Running native segment forTRTEngineOp_000_009 due to failure in verifying input shapes: Incorrect batch dimension, for TRTEngineOp_000_009: [[0,1], [0,1]]
2023-01-22 09:20:04.433500: W tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:5927] TF-TRT Warning: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,1] is an empty tensor, which is not supported by TRT
2023-01-22 09:20:04.433578: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:1104] TF-TRT Warning: Engine creation for TRTEngineOp_000_002 failed. The native segment will be used instead. Reason: UNIMPLEMENTED: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,1] is an empty tensor, which is not supported by TRT
2023-01-22 09:20:04.433656: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:937] TF-TRT Warning: Engine retrieval for input shapes: [[1,0,1], [1,0,1], [1,0,3], [1,0,3]] failed. Running native segment for TRTEngineOp_000_002
2023-01-22 09:20:04.434333: W tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:5927] TF-TRT Warning: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,1] is an empty tensor, which is not supported by TRT
2023-01-22 09:20:04.434428: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:1104] TF-TRT Warning: Engine creation for TRTEngineOp_000_003 failed. The native segment will be used instead. Reason: UNIMPLEMENTED: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,1] is an empty tensor, which is not supported by TRT
2023-01-22 09:20:04.434532: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:937] TF-TRT Warning: Engine retrieval for input shapes: [[1,0,1], [1,0,1], [1,0,3], [1,0,3]] failed. Running native segment for TRTEngineOp_000_003
2023-01-22 09:20:04.434842: W tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:5927] TF-TRT Warning: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,1] is an empty tensor, which is not supported by TRT
2023-01-22 09:20:04.434927: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:1104] TF-TRT Warning: Engine creation for TRTEngineOp_000_004 failed. The native segment will be used instead. Reason: UNIMPLEMENTED: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,1] is an empty tensor, which is not supported by TRT
2023-01-22 09:20:04.435017: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:937] TF-TRT Warning: Engine retrieval for input shapes: [[1,0,1], [1,0,1], [1,0,3], [1,0,3]] failed. Running native segment for TRTEngineOp_000_004
2023-01-22 09:20:04.435190: W tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:5927] TF-TRT Warning: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,1] is an empty tensor, which is not supported by TRT
2023-01-22 09:20:04.435338: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:1104] TF-TRT Warning: Engine creation for TRTEngineOp_000_005 failed. The native segment will be used instead. Reason: UNIMPLEMENTED: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,1] is an empty tensor, which is not supported by TRT
2023-01-22 09:20:04.435474: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:937] TF-TRT Warning: Engine retrieval for input shapes: [[1,0,1], [1,0,1], [1,0,2], [1,0,2]] failed. Running native segment for TRTEngineOp_000_005
2023-01-22 09:20:07.408563: W tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:5927] TF-TRT Warning: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,3,2] is an empty tensor, which is not supported by TRT
2023-01-22 09:20:07.408636: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:1104] TF-TRT Warning: Engine creation for TRTEngineOp_000_017 failed. The native segment will be used instead. Reason: UNIMPLEMENTED: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,3,2] is an empty tensor, which is not supported by TRT
2023-01-22 09:20:07.408681: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:937] TF-TRT Warning: Engine retrieval for input shapes: [[1,0,3,2], [1,0,3,2], [1,0,100,3,2], [1,0,2,2], [1,0,100,2,2], [1,10,2,2], [1,10,100,2,2], [1,0,3,2], [1,0,100,3,2], [1,0,100,3,2]] failed. Running native segment for TRTEngineOp_000_017
2023-01-22 09:20:07.424732: W tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:5927] TF-TRT Warning: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,100,3] is an empty tensor, which is not supported by TRT
2023-01-22 09:21:05.886347: W tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:5927] TF-TRT Warning: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,4] is an empty tensor, which is not supported by TRT
2023-01-22 09:21:05.887228: W tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:5927] TF-TRT Warning: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,4] is an empty tensor, which is not supported by TRT
2023-01-22 09:21:06.891027: W tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc:5927] TF-TRT Warning: Validation failed for TensorRTInputPH_0 and input slot 0: Input tensor with shape [1,0,3] is an empty tensor, which is not supported by TRT
2023-01-22 09:21:06.947776: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:847] TF-TRT Warning: Running native segment forTRTEngineOp_000_009 due to failure in verifying input shapes: Incorrect batch dimension, for TRTEngineOp_000_009: [[0,1], [0,1]]
2023-01-22 09:21:06.948063: W tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc:847] TF-TRT Warning: Running native segment forTRTEngineOp_000_008 due to failure in verifying input shapes: Incorrect batch dimension, for TRTEngineOp_000_008: [[0,1], [0,1]]
Done Converting to TF-TRT FP16
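
The log above lists 20 engines, and most of the ones that later fall back to the native segment (TRTEngineOp_000_002 to 005 and 007 to 009) contain only 4 or 5 nodes; TRTEngineOp_000_017 is a larger exception. One thing that could be tried, as an assumption rather than anything confirmed in this thread, is raising `minimum_segment_size` so that tiny segments stay in TensorFlow. The sketch below parses the `Replaced segment ...` lines to see which engines would remain at a given threshold. `LOG_EXCERPT`, `segment_sizes`, `engines_kept`, and `make_params` are hypothetical names, the threshold of 10 is arbitrary, and the node counts in the log are only an approximation of what the segmenter compares against.

```python
import re

# A few lines copied from the conversion log above (timestamps and paths shortened).
LOG_EXCERPT = """\
convert_graph.cc:913] Replaced segment 0 consisting of 30 nodes by TRTEngineOp_000_000.
convert_graph.cc:913] Replaced segment 2 consisting of 5 nodes by TRTEngineOp_000_002.
convert_graph.cc:913] Replaced segment 7 consisting of 4 nodes by TRTEngineOp_000_007.
convert_graph.cc:913] Replaced segment 13 consisting of 794 nodes by TRTEngineOp_000_013.
convert_graph.cc:913] Replaced segment 17 consisting of 31 nodes by TRTEngineOp_000_017.
"""

SEGMENT_RE = re.compile(r"Replaced segment \d+ consisting of (\d+) nodes by (\S+?)\.?$")


def segment_sizes(log_text):
    """Map each TRTEngineOp name to the node count reported in the log."""
    sizes = {}
    for line in log_text.splitlines():
        match = SEGMENT_RE.search(line)
        if match:
            sizes[match.group(2)] = int(match.group(1))
    return sizes


def engines_kept(sizes, minimum_segment_size):
    """Names of engines whose segment has at least `minimum_segment_size` nodes."""
    return sorted(name for name, n in sizes.items() if n >= minimum_segment_size)


def make_params(minimum_segment_size, precision_mode="FP16"):
    """Build conversion params the same way the conversion script does.

    Needs a TensorFlow build with TensorRT support, so the import is local
    and this function is not called in this sketch.
    """
    from tensorflow.python.compiler.tensorrt import trt_convert as trt
    return trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode=precision_mode,
        minimum_segment_size=minimum_segment_size,
    )


sizes = segment_sizes(LOG_EXCERPT)
print(engines_kept(sizes, minimum_segment_size=10))
```

Note that TRTEngineOp_000_017 would still be converted at this threshold, so the empty-tensor warning for that engine would likely remain.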
  3. Run inference on an image with the converted saved_model using the following script.
    This script runs successfully and returns inference results, but I get warnings and slow inference.
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
import cv2
import numpy as np

image_file_path = '<image file path>'
MODEL_PATH = '<model path>'
m = tf.saved_model.load(MODEL_PATH)
ff = m.signatures[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]


img = cv2.imread(image_file_path).copy()
img = cv2.resize(img, dsize=(512,512))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

input_tensor = img[tf.newaxis, ...].astype(np.uint8)
input_tensor = tf.convert_to_tensor(input_tensor)

results = ff(input_tensor)
print(results)

I’ve been struggling to resolve these warnings for over two weeks. Thanks.

Hi,

Could you please try the latest TensorRT version, 8.5.2?
You can also use the latest NGC container: TensorFlow Release Notes :: NVIDIA Deep Learning Frameworks Documentation.

If you still face this issue, could you please share a sample image and the model so we can try it on our end?

Thank you.

Hi,
We recommend checking the sample links below in case of TF-TRT integration issues.

If the issue persists, we recommend reaching out to the TensorFlow forum.
Thanks!

Thank you for replying.

I tried TensorRT 8.5 (using container nvcr.io/nvidia/tensorflow:22.12-tf2-py3), but I get the same warnings and slow inference.

Here are my TF-TRT model and images:
models

images

Hi,

Sorry for the delayed response.
We couldn’t reproduce similar warnings (for the CenterNet Resnet50 V1 FPN Keypoints 512x512 model) on Tesla V100 GPUs.
Could you please try the latest TensorFlow NGC container, nvcr.io/nvidia/tensorflow:23.01-tf2-py3?

Please find the output logs.
tf-trt_log.txt (12.7 KB)

If you still face this issue, please share the fine-tuned model with us (without converting it to TF-TRT).

Thank you.

Thanks for your investigation.
I got it. There might have been a mistake.

I’ll try again. Thank you.
