TensorRT model always returns NaN output

Description

I’m trying to run SuperGlue with TensorRT; the model has a dynamic-shape output. The engine runs without any error, but the result is always NaN. When I size my output buffer larger than the expected output shape, I get NaN for the region corresponding to the expected shape and 0.0 for the rest.
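For reference, once every input shape has been set on the execution context, TensorRT 8.6 can report the concrete output shape, so the host buffer can be sized exactly instead of guessing an upper bound. A minimal sketch (the concrete_output_shapes helper and the input_shapes dict are illustrative, not part of my scripts below):

import tensorrt as trt

def concrete_output_shapes(engine, context, input_shapes):
    # input_shapes: tensor name -> concrete shape, e.g. {"keypoints_0": (1, 14, 2)}
    for name, shape in input_shapes.items():
        context.set_input_shape(name, shape)
    # With all input shapes fixed, dynamic output dims resolve to real values
    return {
        name: tuple(context.get_tensor_shape(name))
        for name in (engine.get_tensor_name(i) for i in range(engine.num_io_tensors))
        if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT
    }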

Environment

TensorRT Version: 8.6.2.3
GPU Type: NVIDIA Jetson Orin NX (16 GB RAM)
Nvidia Driver Version: Jetpack 6.0 DP
CUDA Version: 12.2.140
CUDNN Version: 8.9.4.25
Operating System + Version: Ubuntu 22.04 LTS
Python Version (if applicable): 3.10.12

Relevant Files

SuperGlue PyTorch implementation: GitHub - magicleap/SuperGluePretrainedNetwork: SuperGlue: Learning Feature Matching with Graph Neural Networks (CVPR 2020, Oral)
convert_to_onnx: SuperPoint-SuperGlue-TensorRT/convert2onnx/convert_superglue_to_onnx.py at main · yuefanhao/SuperPoint-SuperGlue-TensorRT · GitHub
build_engine.py:

import tensorrt as trt

# Initialize TensorRT logger and builder
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(TRT_LOGGER)
config = builder.create_builder_config()


# Create an empty timing cache and attach it to the builder config
cache = config.create_timing_cache(b"")
config.set_timing_cache(cache, ignore_mismatch=False)


flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
parser = trt.OnnxParser(network, TRT_LOGGER)

path_onnx_model = "/home/jetson/a/TensorRT/trt/superglue_outdoor_sim_int32.onnx"

with open(path_onnx_model, "rb") as f:
    if not parser.parse(f.read()):
        print(f"ERROR: Failed to parse the ONNX file {path_onnx_model}")
        for error in range(parser.num_errors):
            print(parser.get_error(error))
        raise SystemExit(1)


inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
print(outputs[0])

profile = builder.create_optimization_profile()

# inputs[0]: keypoints_0, shape (1, N, 2)
min_shape = [1, 1, 2]
opt_shape = [1, 1024, 2]
max_shape = [1, 2048, 2]
profile.set_shape(inputs[0].name, min_shape, opt_shape, max_shape)

# inputs[1]: scores_0, shape (1, N)
min_shape = [1, 1]
opt_shape = [1, 1024]
max_shape = [1, 2048]
profile.set_shape(inputs[1].name, min_shape, opt_shape, max_shape)

# inputs[2]: descriptors_0, shape (1, 256, N)
min_shape = [1, 256, 1]
opt_shape = [1, 256, 1024]
max_shape = [1, 256, 2048]
profile.set_shape(inputs[2].name, min_shape, opt_shape, max_shape)

# inputs[3]: keypoints_1, shape (1, N, 2)
min_shape = [1, 1, 2]
opt_shape = [1, 1024, 2]
max_shape = [1, 2048, 2]
profile.set_shape(inputs[3].name, min_shape, opt_shape, max_shape)

# inputs[4]: scores_1, shape (1, N)
min_shape = [1, 1]
opt_shape = [1, 1024]
max_shape = [1, 2048]
profile.set_shape(inputs[4].name, min_shape, opt_shape, max_shape)

# inputs[5]: descriptors_1, shape (1, 256, N)
min_shape = [1, 256, 1]
opt_shape = [1, 256, 1024]
max_shape = [1, 256, 2048]
profile.set_shape(inputs[5].name, min_shape, opt_shape, max_shape)

config.add_optimization_profile(profile)


# The result of this call is unused; no calibration profile is configured
config.get_calibration_profile()


# Check whether fast FP16 is available
# print(builder.platform_has_fast_fp16)

config.set_flag(trt.BuilderFlag.FP16)

# Build and serialize the engine
engine_bytes = builder.build_serialized_network(network, config)
if engine_bytes is None:
    raise SystemExit("ERROR: engine build failed; see the log above")

engine_path = "superglue.engine"
with open(engine_path, "wb") as f:
    f.write(engine_bytes)
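Since FP16 is enabled unconditionally here and the build log below warns about subnormal FP16 weights, one variation worth trying is gating the flag on platform support and keeping an FP32 build around for comparison. A minimal sketch:

# Only request FP16 when the platform supports it; omitting the flag entirely
# yields an FP32 reference engine to compare against the NaN output
if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)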

inference.py:

import numpy as np
import tensorrt as trt
from cuda import cuda, cudart
import ctypes
from typing import Optional, List
import torch




def check_cuda_err(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("Cuda Error: {}".format(err))
    elif isinstance(err, cudart.cudaError_t):
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError("Cuda Runtime Error: {}".format(err))
    else:
        raise RuntimeError("Unknown error type: {}".format(err))

def cuda_call(call):
    err, res = call[0], call[1:]
    check_cuda_err(err)
    if len(res) == 1:
        res = res[0]
    return res
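# cuda-python calls return a tuple whose first element is the status code;
# cuda_call() checks it and hands back the remaining result(s), e.g.:
#   dev_ptr = cuda_call(cudart.cudaMalloc(1024))  # checks err, returns the pointer
#   cuda_call(cudart.cudaFree(dev_ptr))           # checks err, returns nothing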



class HostDeviceMem:
    """Pair of host and device memory, where the host memory is wrapped in a numpy array"""
    def __init__(self, size: int, dtype: np.dtype):
        nbytes = size * dtype.itemsize
        host_mem = cuda_call(cudart.cudaMallocHost(nbytes))
        pointer_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))

        self._host = np.ctypeslib.as_array(ctypes.cast(host_mem, pointer_type), (size,))
        self._device = cuda_call(cudart.cudaMalloc(nbytes))
        self._nbytes = nbytes

    @property
    def host(self) -> np.ndarray:
        return self._host

    @host.setter
    def host(self, arr: np.ndarray):
        if arr.size > self.host.size:
            raise ValueError(
                f"Tried to fit an array of size {arr.size} into host memory of size {self.host.size}"
            )
        #np.copyto(self.host[:arr.size], arr.flat, casting='safe')
        np.copyto(self.host[:arr.size], arr.flat)

    @property
    def device(self) -> int:
        return self._device

    @property
    def nbytes(self) -> int:
        return self._nbytes

    def __str__(self):
        return f"Host:\n{self.host}\nDevice:\n{self.device}\nSize:\n{self.nbytes}\n"

    def __repr__(self):
        return self.__str__()

    def free(self):
        cuda_call(cudart.cudaFree(self.device))
        cuda_call(cudart.cudaFreeHost(self.host.ctypes.data))


# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
# For dynamic shapes, the caller supplies the concrete shape of every I/O tensor
# (in binding order) via inputs_shape.
def allocate_buffers(engine: trt.ICudaEngine, inputs_shape):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda_call(cudart.cudaStreamCreate())
    tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    print("Tensor Names:", tensor_names)
    for binding in tensor_names:
        # get_binding_shape is deprecated in TRT 8.6 (see the DeprecationWarnings
        # in the log below); dynamic dimensions are reported as -1 here
        shape = engine.get_binding_shape(binding)
        if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
            print("Input Shape for binding index", binding, ":", shape)
        else:
            print("Output at binding index", binding, ":", shape)
    for shape, binding in zip(inputs_shape, tensor_names):
        size = trt.volume(shape)
        # Get the tensor data type
        dtype = np.dtype(trt.nptype(engine.get_tensor_dtype(binding)))
        print(dtype)

        # Allocate host and device buffers
        bindingMemory = HostDeviceMem(size, dtype)

        # Append the device buffer to device bindings.
        bindings.append(int(bindingMemory.device))

        # Append to the appropriate list.
        if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
            inputs.append(bindingMemory)
        else:
            print("injayim")
            print(dtype)
            outputs.append(bindingMemory)

    return inputs, outputs, bindings, stream


# Frees the resources allocated in allocate_buffers
def free_buffers(inputs: List[HostDeviceMem], outputs: List[HostDeviceMem], stream: cudart.cudaStream_t):
    for mem in inputs + outputs:
        mem.free()
    cuda_call(cudart.cudaStreamDestroy(stream))


# Wrapper for cudaMemcpy which infers copy size and does error checking
def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(cudart.cudaMemcpy(device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice))


# Wrapper for cudaMemcpy which infers copy size and does error checking
def memcpy_device_to_host(host_arr: np.ndarray, device_ptr: int):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(cudart.cudaMemcpy(host_arr, device_ptr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost))


def _do_inference_base(inputs, outputs, stream, execute_async):
    # Transfer input data to the GPU.
    kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
    [cuda_call(cudart.cudaMemcpyAsync(inp.device, inp.host, inp.nbytes, kind, stream)) for inp in inputs]
    # Run inference.
    execute_async()
    # Transfer predictions back from the GPU.
    kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
    [cuda_call(cudart.cudaMemcpyAsync(out.host, out.device, out.nbytes, kind, stream)) for out in outputs]
    # Synchronize the stream
    cuda_call(cudart.cudaStreamSynchronize(stream))
    # Return only the host outputs.
    return [out.host for out in outputs]


# This function is generalized for multiple inputs/outputs.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    def execute_async():
        context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream)
    return _do_inference_base(inputs, outputs, stream, execute_async)
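# NOTE: do_inference() uses the implicit-batch execute_async API and is unused
# below; this engine was built with EXPLICIT_BATCH, so do_inference_v2() is used.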


# This function is generalized for multiple inputs/outputs for full dimension networks.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference_v2(context, bindings, inputs, outputs, stream):
    def execute_async():
        context.execute_async_v2(bindings=bindings, stream_handle=stream)
    return _do_inference_base(inputs, outputs, stream, execute_async)

    
# Placeholder random inputs (overwritten by the saved tensors loaded below)
kpts0 = np.random.randint(0, 255, (1, 14, 2)) / 255
scores0 = np.random.randint(0, 255, (1, 1293)) / 255
desc0 = np.random.randint(0, 255, (1, 256, 1293)) / 255
kpts1 = np.random.randint(0, 255, (1, 1246, 2)) / 255
scores1 = np.random.randint(0, 255, (1, 1246)) / 255
desc1 = np.random.randint(0, 255, (1, 256, 1246)) / 255


# Real inputs saved from the PyTorch model; note both images reuse the same tensors
kpts0 = torch.load('/home/jetson/Downloads/keypoints.pt', map_location=torch.device('cpu'))
kpts0 = kpts0.numpy()

scores0 = torch.load('/home/jetson/Downloads/keypoint_scores.pt', map_location=torch.device('cpu'))
scores0 = scores0.numpy()

desc0 = torch.load('/home/jetson/Downloads/descriptors.pt', map_location=torch.device('cpu'))
desc0 = desc0.numpy()
desc0 = desc0.transpose(0, 2, 1)

kpts1 = torch.load('/home/jetson/Downloads/keypoints.pt', map_location=torch.device('cpu'))
kpts1 = kpts1.numpy()

scores1 = torch.load('/home/jetson/Downloads/keypoint_scores.pt', map_location=torch.device('cpu'))
scores1 = scores1.numpy()

desc1 = torch.load('/home/jetson/Downloads/descriptors.pt', map_location=torch.device('cpu'))
desc1 = desc1.numpy()
desc1 = desc1.transpose(0, 2, 1)



# Function to load a TensorRT engine from a file
def load_engine(engine_file_path):
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())

# Load the engine
engine_file_path = "/home/jetson/a/TensorRT/superglue.engine"
engine = load_engine(engine_file_path)



context = engine.create_execution_context()

# NOTE: get_binding_index/set_binding_shape are deprecated in TRT 8.6 (hence the
# DeprecationWarnings in the log below); get_tensor_name/set_input_shape replace them
input_binding_index = engine.get_binding_index("keypoints_0")
context.set_binding_shape(input_binding_index, kpts0.shape)
input_binding_index = engine.get_binding_index("scores_0")
context.set_binding_shape(input_binding_index, scores0.shape)
input_binding_index = engine.get_binding_index("descriptors_0")
context.set_binding_shape(input_binding_index, desc0.shape)
input_binding_index = engine.get_binding_index("keypoints_1")
context.set_binding_shape(input_binding_index, kpts1.shape)
input_binding_index = engine.get_binding_index("scores_1")
context.set_binding_shape(input_binding_index, scores1.shape)
input_binding_index = engine.get_binding_index("descriptors_1")
context.set_binding_shape(input_binding_index, desc1.shape)
# The last entry is a guessed upper bound for the dynamic output "scores"
inputs_shape = [kpts0.shape, scores0.shape, desc0.shape, kpts1.shape, scores1.shape, desc1.shape, (1, 100, 100)]
# Allocate memory for inputs and outputs
print("before allocate")
inputs, outputs, bindings, stream = allocate_buffers(engine, inputs_shape)
print('outputs: ', outputs)

print("after allocate")




output_data = do_inference_v2(context, bindings, inputs, outputs, stream)

# Process the output (example)
print("Output:", output_data[0])


# Free allocated memory
free_buffers(inputs, outputs, stream)

Hi @octagpt01 ,
Can you please share the verbose logs with us?

Thanks

This is build_engine.py’s output:

[06/20/2024-12:46:33] [TRT] [I] [MemUsageChange] Init CUDA: CPU +13, GPU +0, now: CPU 36, GPU 3469 (MiB)
[06/20/2024-12:46:40] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +1154, GPU +1225, now: CPU 1226, GPU 4640 (MiB)
[06/20/2024-12:46:40] [TRT] [W] onnx2trt_utils.cpp:372: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
<tensorrt.tensorrt.ITensor object at 0xffffa9723730>
[06/20/2024-12:46:40] [TRT] [W] DLA requests all profiles have same min, max, and opt value. All dla layers are falling back to GPU
[06/20/2024-12:46:41] [TRT] [I] Graph optimization time: 0.65073 seconds.
[06/20/2024-12:46:41] [TRT] [I] Global timing cache in use. Profiling results in this builder pass will be stored.
[06/20/2024-12:53:34] [TRT] [I] Detected 6 inputs and 1 output network tensors.
[06/20/2024-12:53:37] [TRT] [I] Total Host Persistent Memory: 931472
[06/20/2024-12:53:37] [TRT] [I] Total Device Persistent Memory: 36864
[06/20/2024-12:53:37] [TRT] [I] Total Scratch Memory: 281018368
[06/20/2024-12:53:37] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 72 MiB, GPU 284 MiB
[06/20/2024-12:53:37] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 1198 steps to complete.
[06/20/2024-12:53:37] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 392.288ms to assign 22 blocks to 1198 nodes requiring 311439872 bytes.
[06/20/2024-12:53:37] [TRT] [I] Total Activation Memory: 311435776
[06/20/2024-12:53:37] [TRT] [W] TensorRT encountered issues when converting weights between types and that could affect accuracy.
[06/20/2024-12:53:37] [TRT] [W] If this is not the desired behavior, please modify the weights or retrain with regularization to adjust the magnitude of the weights.
[06/20/2024-12:53:37] [TRT] [W] Check verbose logs for the list of affected weights.
[06/20/2024-12:53:37] [TRT] [W] - 200 weights are affected by this issue: Detected subnormal FP16 values.
[06/20/2024-12:53:37] [TRT] [W] - 36 weights are affected by this issue: Detected values less than smallest positive FP16 subnormal value and converted them to the FP16 minimum subnormalized value.
[06/20/2024-12:53:37] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in building engine: CPU +45, GPU +32, now: CPU 45, GPU 32 (MiB)

And this is inference_dynamic.py’s output:

[06/20/2024-13:06:02] [TRT] [W] Using an engine plan file across different models of devices is not recommended and is likely to affect performance or even cause errors.
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:277: DeprecationWarning: Use get_tensor_name instead.
  input_binding_index = engine.get_binding_index("keypoints_0")
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:278: DeprecationWarning: Use set_input_shape instead.
  context.set_binding_shape(input_binding_index, kpts0.shape)
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:279: DeprecationWarning: Use get_tensor_name instead.
  input_binding_index = engine.get_binding_index("scores_0")
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:280: DeprecationWarning: Use set_input_shape instead.
  context.set_binding_shape(input_binding_index, scores0.shape)
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:281: DeprecationWarning: Use get_tensor_name instead.
  input_binding_index = engine.get_binding_index("descriptors_0")
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:282: DeprecationWarning: Use set_input_shape instead.
  context.set_binding_shape(input_binding_index, desc0.shape)
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:283: DeprecationWarning: Use get_tensor_name instead.
  input_binding_index = engine.get_binding_index("keypoints_1")
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:284: DeprecationWarning: Use set_input_shape instead.
  context.set_binding_shape(input_binding_index, kpts1.shape)
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:285: DeprecationWarning: Use get_tensor_name instead.
  input_binding_index = engine.get_binding_index("scores_1")
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:286: DeprecationWarning: Use set_input_shape instead.
  context.set_binding_shape(input_binding_index, scores1.shape)
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:287: DeprecationWarning: Use get_tensor_name instead.
  input_binding_index = engine.get_binding_index("descriptors_1")
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:288: DeprecationWarning: Use set_input_shape instead.
  context.set_binding_shape(input_binding_index, desc1.shape)
before allocate
Tensor Names: ['keypoints_0', 'scores_0', 'descriptors_0', 'keypoints_1', 'scores_1', 'descriptors_1', 'scores']
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:85: DeprecationWarning: Use get_tensor_shape instead.
  shape = engine.get_binding_shape(binding)
Input Shape for binding index keypoints_0 : (1, -1, 2)
Input Shape for binding index scores_0 : (1, -1)
Input Shape for binding index descriptors_0 : (1, 256, -1)
Input Shape for binding index keypoints_1 : (1, -1, 2)
Input Shape for binding index scores_1 : (1, -1)
Input Shape for binding index descriptors_1 : (1, 256, -1)
/home/jetson/Danial/a/TensorRT/trt/inference_dynamic.py:89: DeprecationWarning: Use get_tensor_shape instead.
  shape = engine.get_binding_shape(binding)
Output at binding index scores : (1, -1, -1)
float32
float32
float32
float32
float32
float32
float32
injayim
float32
outputs:  [Host:
[0. 0. 0. ... 0. 0. 0.]
Device:
8671031296
Size:
40000
]
after allocate
Output: [nan nan nan ... nan nan nan]