Inference Speed

On the Jetson Xavier NX development board, I converted my PyTorch ResNet-50 model to TensorRT and quantized it with FP16 and INT8 respectively. However, when I ran inference in these two modes, the speed was the same, which puzzled me. Is this normal?

Hi,

Do you convert the model with trtexec?
If yes, could you share the profiling output with us?

For example:

$ /usr/src/tensorrt/bin/trtexec --onnx=/usr/src/tensorrt/data/resnet50/ResNet50.onnx --fp16
&&&& RUNNING TensorRT.trtexec [TensorRT v8502] # /usr/src/tensorrt/bin/trtexec --onnx=/usr/src/tensorrt/data/resnet50/ResNet50.onnx --fp16
[03/27/2023-10:40:59] [I] === Model Options ===
[03/27/2023-10:40:59] [I] Format: ONNX
[03/27/2023-10:40:59] [I] Model: /usr/src/tensorrt/data/resnet50/ResNet50.onnx
[03/27/2023-10:40:59] [I] Output:
...
[03/27/2023-10:43:03] [I] 
[03/27/2023-10:43:03] [I] === Performance summary ===
[03/27/2023-10:43:03] [I] Throughput: 380.223 qps
[03/27/2023-10:43:03] [I] Latency: min = 2.60376 ms, max = 2.69702 ms, mean = 2.67666 ms, median = 2.67651 ms, percentile(90%) = 2.68518 ms, percentile(95%) = 2.68726 ms, percentile(99%) = 2.69299 ms
[03/27/2023-10:43:03] [I] Enqueue Time: min = 0.357056 ms, max = 0.755615 ms, mean = 0.377569 ms, median = 0.375488 ms, percentile(90%) = 0.386719 ms, percentile(95%) = 0.392639 ms, percentile(99%) = 0.420166 ms
[03/27/2023-10:43:03] [I] H2D Latency: min = 0.0450439 ms, max = 0.0697021 ms, mean = 0.0467225 ms, median = 0.0463867 ms, percentile(90%) = 0.0481567 ms, percentile(95%) = 0.0487061 ms, percentile(99%) = 0.0510254 ms
[03/27/2023-10:43:03] [I] GPU Compute Time: min = 2.55505 ms, max = 2.64722 ms, mean = 2.62598 ms, median = 2.6261 ms, percentile(90%) = 2.63403 ms, percentile(95%) = 2.63647 ms, percentile(99%) = 2.64189 ms
[03/27/2023-10:43:03] [I] D2H Latency: min = 0.00268555 ms, max = 0.00585938 ms, mean = 0.00396002 ms, median = 0.00390625 ms, percentile(90%) = 0.00463867 ms, percentile(95%) = 0.00476074 ms, percentile(99%) = 0.00537109 ms
[03/27/2023-10:43:03] [I] Total Host Walltime: 3.00613 s
[03/27/2023-10:43:03] [I] Total GPU Compute Time: 3.0015 s
[03/27/2023-10:43:03] [I] Explanations of the performance metrics are printed in the verbose logs.
[03/27/2023-10:43:03] [I] 
&&&& PASSED TensorRT.trtexec [TensorRT v8502] # /usr/src/tensorrt/bin/trtexec --onnx=/usr/src/tensorrt/data/resnet50/ResNet50.onnx --fp16

Thanks.

Hi!
I used torch2trt to convert my PyTorch model to TensorRT.
This is my conversion code. FP16 means fp16_mode=True and int8_mode=False, while INT8 means fp16_mode=False and int8_mode=True (a sketch of the INT8 variant follows the script below).

import torch
import torchvision
from loguru import logger
from torch2trt import torch2trt

# Build a torchvision ResNet-50 and put it in eval mode on the GPU.
model = torchvision.models.resnet50()
model = model.cuda()
model.eval()

# Example input matching the deployment resolution (1 x 3 x 320 x 320).
x = torch.randn(1, 3, 320, 320).cuda()

# FP16 conversion: fp16_mode=True, int8_mode=False.
model_trt = torch2trt(
    model,
    [x],
    fp16_mode=True,
    int8_mode=False,
)

torch.save(model_trt.state_dict(), "resnet50_fp16.pth")
logger.info("Converted TensorRT model done.")

# Serialize the TensorRT engine for C++ inference.
engine_file = "resnet50_fp16.engine"
with open(engine_file, "wb") as f:
    f.write(model_trt.engine.serialize())

logger.info("Converted TensorRT model engine file is saved for C++ inference.")

And this is my test code. I just change the engine file path “/home/marcohan/xieqiang/Endoscope_wyj/resnet50_fp16.engine” to test the different engines generated by torch2trt.

#include <fstream>
#include <iostream>
#include <sstream>
#include <numeric>
#include <chrono>
#include <vector>
#include <opencv2/opencv.hpp>
#include <dirent.h>
#include <stdio.h>
#include <stdlib.h>
#include "NvInfer.h"
#include "cuda_runtime_api.h"
#include "logging.h"
#include "NvInferPlugin.h"
#include <string>
#include <time.h>

using namespace std;


static const int INPUT_W = 320;
static const int INPUT_H = 320;
static const int OUTPUT_SIZE = 1000;  // output elements per image (1000 classes for torchvision ResNet-50; adjust for other models)
const char* INPUT_BLOB_NAME = "input_0";
const char* OUTPUT_BLOB_NAME = "output_0";
static Logger gLogger;
#define DEVICE 0  // GPU id

using namespace nvinfer1;

#define CHECK(status) \
    do\
    {\
        auto ret = (status);\
        if (ret != 0)\
        {\
            std::cerr << "Cuda failure: " << ret << std::endl;\
            abort();\
        }\
    } while (0)




// Convert an HWC uint8 RGB image into a normalized CHW float blob (ImageNet mean/std).
float* blobFromImage(cv::Mat& img) {
    float norm[2][3] = { {0.485, 0.456, 0.406},
                     {0.229, 0.224, 0.225} };
    float* blob = new float[img.total() * 3];
    int channels = 3;
    int img_h = img.rows;
    int img_w = img.cols;
    for (size_t c = 0; c < channels; c++)
    {
        for (size_t h = 0; h < img_h; h++)
        {
            for (size_t w = 0; w < img_w; w++)
            {
                blob[c * img_w * img_h + h * img_w + w] =
                    ((float)img.at<cv::Vec3b>(h, w)[c] / 255 - norm[0][c]) / norm[1][c];
            }
        }
    }
    return blob;
}

void doInference(IExecutionContext & context, float* input, float* output, int batchSize)
{
    const ICudaEngine& engine = context.getEngine();

    // Pointers to input and output device buffers to pass to engine.
    // Engine requires exactly IEngine::getNbBindings() number of buffers.
    assert(engine.getNbBindings() == 2);
    void* buffers[2];

    // In order to bind the buffers, we need to know the names of the input and output tensors.
    // Note that indices are guaranteed to be less than IEngine::getNbBindings()
    const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME);
    const int outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME);

    // Create GPU buffers on device
    CHECK(cudaMalloc(&buffers[inputIndex], batchSize * 3 * INPUT_H * INPUT_W * sizeof(float)));
    CHECK(cudaMalloc(&buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float)));

    // Create stream
    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));

    // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host
    CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream));
    context.enqueue(batchSize, buffers, stream, nullptr);   // run inference on the stream
    CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));  // copy the predictions back to the host
    cudaStreamSynchronize(stream);      // wait for the stream to finish

    // Release stream and buffers
    cudaStreamDestroy(stream);
    CHECK(cudaFree(buffers[inputIndex]));
    CHECK(cudaFree(buffers[outputIndex]));
}




void GetFileNames(string path,vector<string>& filenames)
{
    DIR *pDir;
    struct dirent* ptr;
    if(!(pDir = opendir(path.c_str()))){
        cout<<"Folder doesn't Exist!"<<endl;
        return;
    }
    while ((ptr = readdir(pDir)) != 0) {
        if (strcmp(ptr->d_name, ".") != 0 && strcmp(ptr->d_name, "..") != 0) {
            filenames.push_back(ptr->d_name);
        }
    }
    closedir(pDir);
}



int main(int argc, char** argv)
{

    cudaSetDevice(DEVICE);
    char* trtModelStream{ nullptr };
    size_t size{ 0 };
    initLibNvInferPlugins(&gLogger.getTRTLogger(), "");
    std::ifstream file("/home/marcohan/xieqiang/Endoscope_wyj/resnet50_fp16.engine", std::ios::binary);
    if (file.good()) {
        file.seekg(0, file.end);
        size = file.tellg();
        file.seekg(0, file.beg);
        trtModelStream = new char[size];
        assert(trtModelStream);
        file.read(trtModelStream, size);    // read the engine file into trtModelStream
        file.close();
    }
    

    // Deserialize trtModelStream into an ICudaEngine and create the execution context
    IRuntime* runtime = createInferRuntime(gLogger);
    assert(runtime != nullptr);
    ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size);
    assert(engine != nullptr);
    IExecutionContext* context = engine->createExecutionContext();
    assert(context != nullptr);
    delete[] trtModelStream;

    string path = "data";
    vector<string> files;
    GetFileNames(path, files);
    for (int i = 0; i < files.size(); i++) {
        const std::string input_image_path = path + "/" + files[i];
        // Subtract mean from image
        cv::Mat img = cv::imread(input_image_path);
        cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
        int img_w = img.cols;
        int img_h = img.rows;
        cv::Mat pr_img;
        cv::resize(img, pr_img, cv::Size(INPUT_W, INPUT_H), 0, 0, cv::INTER_CUBIC);

        float* blob;
        blob = blobFromImage(pr_img);
        //float scale = std::min(INPUT_W / (img.cols * 1.0), INPUT_H / (img.rows * 1.0));

        // Run inference
        static float prob[OUTPUT_SIZE];

        // timespec t1, t2;
        // int deltaT;
        // clock_gettime(CLOCK_MONOTONIC, &t1);
        auto start = std::chrono::system_clock::now();
        doInference(*context, blob, prob, 1);   // run inference on a single image
        auto end = std::chrono::system_clock::now();
        // clock_gettime(CLOCK_MONOTONIC, &t2);
        // deltaT = (t2.tv_sec - t1.tv_sec) * 10^9 + t2.tv_nsec - t1.tv_nsec;
        // cout<<deltaT<<endl;

        std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
        delete[] blob;
        //cv::Mat out = Array2Mat(pr_img, prob, 320, 320);
        // cv::Mat out = Array2Mat(prob, 320, 320);
        // cv::resize(out, out, cv::Size(img_w, img_h), 0, 0, 2);
        // cv::cvtColor(out, out, cv::COLOR_RGB2BGR);
        /*cv::Mat back(img_w, img_h, CV_8UC3, cv::Scalar(160, 80, 160));
        bitwise_not(out, out);
        img.copyTo(back, out);
        cv::cvtColor(back, back, cv::COLOR_RGB2BGR);*/
        
        // cv::imwrite("result/" + files[i], out);
        
    }

    return 0;
}

Thank you!

Hi,

Do you get the performance reports from TensorRT?

If not, how do you measure the performance?
Could you share some data with us?
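
For example, you can profile a serialized engine directly with trtexec (the engine path below is only a placeholder; point it at your own file):

$ /usr/src/tensorrt/bin/trtexec --loadEngine=resnet50_fp16.engine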

Thanks.

Hi, I think I got the correct result on the ResNet-50 model. But I have another question: I want to know whether the result is reasonable for my custom semantic-segmentation model, because the throughput difference between the two modes is only slight. Here are the results.

FP16 deployment

[03/28/2023-15:59:00] [I] === Performance summary ===
[03/28/2023-15:59:00] [I] Throughput: 37.9287 qps
[03/28/2023-15:59:00] [I] Latency: min = 25.407 ms, max = 45.8411 ms, mean = 26.355 ms, median = 25.4744 ms, percentile(99%) = 45.8411 ms
[03/28/2023-15:59:00] [I] End-to-End Host Latency: min = 25.4133 ms, max = 45.8525 ms, mean = 26.3647 ms, median = 25.4861 ms, percentile(99%) = 45.8525 ms
[03/28/2023-15:59:00] [I] Enqueue Time: min = 24.636 ms, max = 48.5273 ms, mean = 26.4637 ms, median = 25.2899 ms, percentile(99%) = 48.5273 ms
[03/28/2023-15:59:00] [I] H2D Latency: min = 0.0493164 ms, max = 0.0917969 ms, mean = 0.0515442 ms, median = 0.050293 ms, percentile(99%) = 0.0917969 ms
[03/28/2023-15:59:00] [I] GPU Compute Time: min = 25.3403 ms, max = 45.729 ms, mean = 26.286 ms, median = 25.4066 ms, percentile(99%) = 45.729 ms
[03/28/2023-15:59:00] [I] D2H Latency: min = 0.0168457 ms, max = 0.0202637 ms, mean = 0.0174255 ms, median = 0.0170898 ms, percentile(99%) = 0.0202637 ms
[03/28/2023-15:59:00] [I] Total Host Walltime: 1.05461 s
[03/28/2023-15:59:00] [I] Total GPU Compute Time: 1.05144 s
[03/28/2023-15:59:00] [W] * Throughput may be bound by Enqueue Time rather than GPU Compute and the GPU may be under-utilized.
[03/28/2023-15:59:00] [W]   If not already in use, --useCudaGraph (utilize CUDA graphs where possible) may increase the throughput.
[03/28/2023-15:59:00] [I] Explanations of the performance metrics are printed in the verbose logs.
[03/28/2023-15:59:00] [I] 
&&&& PASSED TensorRT.trtexec [TensorRT v8001] # ./trtexec --loadEngine=/home/marcohan/xieqiang/test/136000_fp16.engine
[03/28/2023-15:59:00] [I] [TRT] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 1377, GPU 6893 (MiB)

INT8 deployment

[03/28/2023-15:59:40] [I] === Performance summary ===
[03/28/2023-15:59:40] [I] Throughput: 42.0215 qps
[03/28/2023-15:59:40] [I] Latency: min = 23.0854 ms, max = 37.5281 ms, mean = 23.7861 ms, median = 23.119 ms, percentile(99%) = 37.5281 ms
[03/28/2023-15:59:40] [I] End-to-End Host Latency: min = 23.0957 ms, max = 37.5479 ms, mean = 23.7967 ms, median = 23.1301 ms, percentile(99%) = 37.5479 ms
[03/28/2023-15:59:40] [I] Enqueue Time: min = 22.7603 ms, max = 70.5415 ms, mean = 24.5922 ms, median = 22.9861 ms, percentile(99%) = 70.5415 ms
[03/28/2023-15:59:40] [I] H2D Latency: min = 0.0493164 ms, max = 0.0778809 ms, mean = 0.0515137 ms, median = 0.0500488 ms, percentile(99%) = 0.0778809 ms
[03/28/2023-15:59:40] [I] GPU Compute Time: min = 23.0186 ms, max = 37.4211 ms, mean = 23.7171 ms, median = 23.0525 ms, percentile(99%) = 37.4211 ms
[03/28/2023-15:59:40] [I] D2H Latency: min = 0.0166016 ms, max = 0.0290527 ms, mean = 0.0174212 ms, median = 0.0170898 ms, percentile(99%) = 0.0290527 ms
[03/28/2023-15:59:40] [I] Total Host Walltime: 0.999487 s
[03/28/2023-15:59:40] [I] Total GPU Compute Time: 0.996119 s
[03/28/2023-15:59:40] [W] * Throughput may be bound by Enqueue Time rather than GPU Compute and the GPU may be under-utilized.
[03/28/2023-15:59:40] [W]   If not already in use, --useCudaGraph (utilize CUDA graphs where possible) may increase the throughput.
[03/28/2023-15:59:40] [I] Explanations of the performance metrics are printed in the verbose logs.
[03/28/2023-15:59:40] [I] 
&&&& PASSED TensorRT.trtexec [TensorRT v8001] # ./trtexec --loadEngine=/home/marcohan/xieqiang/test/136000_int8.engine
[03/28/2023-15:59:40] [I] [TRT] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +0, now: CPU 1382, GPU 6933 (MiB)
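
Per the warning above, I could also re-profile the same engines with CUDA graphs enabled to check whether throughput is bound by enqueue time, e.g.:

$ ./trtexec --loadEngine=/home/marcohan/xieqiang/test/136000_int8.engine --useCudaGraph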

Thanks.

There has been no update from you for a while, so we assume this is no longer an issue.
Hence we are closing this topic. If you need further support, please open a new one.
Thanks

Hi,

Just want to make sure first.

Have you maximized the device performance with the following commands?

$ sudo nvpmodel -m 0
$ sudo jetson_clocks

Thanks.