Why is my inference time so long when using trtexec + FP16?

Hi @dusty_nv @AastaLLL

This is part of my code. I am measuring the inference time of the Inception_v1 model, optimised with “trtexec” + FP16. I used tf2onnx to convert the TensorFlow graph to ONNX.

I have put 2 timers:

  1. To measure the time when sending data from CPU to GPU
  2. Just the inference

QUESTIONS:

  1. Are the results given in seconds, milliseconds, etc.? I am multiplying by 1000.
  2. Why does the inference time vary so much?
  3. Why is it taking so much time? In this GitHub repository, the average seems to be 8 ms:
    https://github.com/NVIDIA-AI-IOT/tf_to_trt_image_classification

RUN 1: [screenshot of timing output]

RUN 2: [screenshot of timing output]

RUN 3: [screenshot of timing output]

SCRIPT

def allocate_buffers(engine, batch_size, data_type):

    """
    This is the function to allocate buffers for input and output in the device (GPU) and host (CPU)
    Args:
      engine : The deserialized TensorRT engine (ICudaEngine).
      batch_size : The batch size used at execution time.
      data_type: The data type for input and output, for example trt.float32, np.float32. 
    
    Output:
      h_input: Input in the host (CPU).
      d_input: Input in the device (GPU). 
      h_output: Output in the host (CPU). 
      d_output: Output in the device (GPU). 
      stream: CUDA stream.

    """

    # Determine dimensions and create page-locked memory buffers (which won't be swapped to disk) to hold host inputs/outputs.
    h_input = cuda.pagelocked_empty(batch_size * trt.volume(engine.get_binding_shape(0)), dtype=trt.nptype(data_type))
    h_output = cuda.pagelocked_empty(batch_size * trt.volume(engine.get_binding_shape(1)), dtype=trt.nptype(data_type))
    
    # Allocate device memory for inputs and outputs (the same size as the host input and output buffers).
    d_input = cuda.mem_alloc(h_input.nbytes)
    d_output = cuda.mem_alloc(h_output.nbytes)
    
    # Create a CUDA stream in which to copy inputs/outputs between host and device and to run inference.
    stream = cuda.Stream()
    return h_input, d_input, h_output, d_output, stream




def load_images_to_buffer(pics, pagelocked_buffer):
    
    """
    This is the function to load (preprocessed) images to buffers in the host
    
    """
    preprocessed = np.asarray(pics).ravel()
    np.copyto(pagelocked_buffer, preprocessed)


 

def load_labels(path, encoding='utf-8'):
    """Loads labels from file (with or without index numbers).
    Args:
        path: path to label file.
        encoding: label file encoding.
        
    Returns:
        Dictionary mapping indices to labels.
       
    """
    with open(path, 'r', encoding=encoding) as f:
        lines = f.readlines()
        if not lines:
            return {}

        if lines[0].split(' ', maxsplit=1)[0].isdigit():
            pairs = [line.split(' ', maxsplit=1) for line in lines]
            return {int(index): label.strip() for index, label in pairs}
        else:
            return {index: line.strip() for index, line in enumerate(lines)}
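
For reference, here is a minimal sketch of the two label-file layouts this function handles; the file name and labels below are made up for illustration:

# labels_indexed.txt          labels_plain.txt
#   0 tench                     tench
#   1 goldfish                  goldfish
# Both layouts yield {0: 'tench', 1: 'goldfish', ...}
# (plain files are indexed by line number).
labels = load_labels('labels_indexed.txt')  # hypothetical file name
print(labels[0])  # -> 'tench'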
        



def load_engine(trt_runtime, plan_path):
    """
    This function reads the serialized engine from the .trt file and deserializes it
    """
    with open(plan_path, 'rb') as f:
        engine_data = f.read()
    engine = trt_runtime.deserialize_cuda_engine(engine_data)
    return engine

# Modify x in [0:x] to display the number of desired predictions
def postprocess_inception(output):
    predictions_top = np.argsort(output)[::-1][0:1]
    labels_top = [labels[p] for p in predictions_top]
    scores_top = output[predictions_top]
    return scores_top, labels_top


# Loads labels file
labels = load_labels(labels_path)


# Loads model file
engine = load_engine(trt_runtime, plan_path)


# Allocate buffers in CPU and GPU - calling TX2_classify.py (allocate_buffers) 
h_input, d_input, h_output, d_output, stream = allocate_buffers(engine, 1, trt.float32)

# Context for executing inference using ICudaEngine
context = engine.create_execution_context()


# Image preprocessing
# Open the image, convert it to RGB (8-bit per channel), resize it using a high-quality
# downsampling filter, then create an FP32 array and normalize it.
image = np.asarray(Image.open(image_path).convert('RGB').resize(size, Image.ANTIALIAS), dtype=np.float32)
image /= 255


# Load the preprocessed image into the CPU's buffer - calling TX2_classify.py (load_images_to_buffer)
load_images_to_buffer(image, h_input)

# Transfer input data from CPU to GPU.
start_0 = time.perf_counter()
cuda.memcpy_htod_async(d_input, h_input, stream)
time_CPUtoGPU = time.perf_counter() - start_0
print("CPUtoGPU(ms):", (time_CPUtoGPU * 1000))
print("\n")


for i in range(count):
    # Run inference.
    #context.profiler = trt.Profiler() ##shows execution time(ms) of each layer
    start_1 = time.perf_counter()
    context.execute(batch_size=1, bindings=[int(d_input), int(d_output)])   
    inference_time = time.perf_counter() - start_1
    print("Inference_time(ms)")
    print(inference_time * 1000)
    print("\n")
    
    # Transfer predictions back from the GPU to the CPU.
    cuda.memcpy_dtoh_async(h_output, d_output, stream)
        
    # Synchronize the stream.
    stream.synchronize()
        
    # Return the CPU output.
    scores = h_output

    pred = postprocess_inception(scores)
    print(pred)
    print("\n")

Hi,

In case you don’t know, please maximize the device performance with the following commands first:

$ sudo nvpmodel -m 0
$ sudo jetson_clocks

1. Based on the documentation below, time.perf_counter() returns a value in seconds.
Since you multiply it by 1000, the result is in milliseconds (ms).
https://docs.python.org/3/library/time.html#time.perf_counter
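
For example, a minimal sketch of the conversion (variable names are illustrative):

import time

start = time.perf_counter()
# ... code to be timed ...
elapsed_s = time.perf_counter() - start   # seconds (float)
print("elapsed(ms):", elapsed_s * 1000)   # multiply by 1000 to get milliseconds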

2. Please maximize the clock rates.
Inference time can vary when the dynamic clock mode (used for power saving) is enabled.

3. Please make sure the timer measures only the inference part and that the engine runs with FP16 precision.

Thanks.

Hi @AastaLLL

Thank you for the quick answer. Check my results now:

3. Please make sure the timer measures only the inference part and that the engine runs with FP16 precision.

I optimised the model with “trtexec” and FP16, and put a timer only around the inference:

These are the results after running jetson_clocks and nvpmodel -m 0:

RUN 1: [screenshot of timing output]

RUN 2: [screenshot of timing output]

RUN 3: [screenshot of timing output]

Why is inference #1 slower than the others? The timer does not include the time to move data from CPU to GPU. Is there something else happening?

The very first time you run inference on a model after loading the application, it typically takes a bit longer because the GPU is initialized, GPU code is loaded, etc. So during benchmarks there is typically a “warm-up” period where inference is run but the results are ignored. Also, if your model uses TensorRT plugins, those are configured on demand, which adds extra time to the first run.

Also, benchmarks typically run many inference iterations (like the trtexec program does) and average the results or compute mean/median statistics over them.
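
As a rough sketch of that idea (reusing context, d_input and d_output from the script above; the warm-up and iteration counts are just example values):

import time
import numpy as np

WARMUP = 10        # assumed warm-up runs, timings discarded
ITERATIONS = 100   # assumed number of timed runs

# Warm-up: run inference but ignore the timings.
for _ in range(WARMUP):
    context.execute(batch_size=1, bindings=[int(d_input), int(d_output)])

# Timed runs.
times_ms = []
for _ in range(ITERATIONS):
    start = time.perf_counter()
    context.execute(batch_size=1, bindings=[int(d_input), int(d_output)])
    times_ms.append((time.perf_counter() - start) * 1000)

print("mean(ms):  ", np.mean(times_ms))
print("median(ms):", np.median(times_ms))
print("min(ms):   ", np.min(times_ms))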
