TensorRT Custom RoiAlign plugin is very slow

Hi,

I have an object detection pipeline in PyTorch that I converted to TensorRT 7. The pipeline uses a custom RoiAlign operation, for which I created a custom plugin. The same CUDA kernel runs in both PyTorch and TensorRT; however, the TensorRT execution is VERY slow.

Here are nvprof results:
TRT
89.91% 52.8266s 100 528.27ms 497.70ms 563.38ms void RoIAlignForward(int, float const *, float, int, int, int, int, int, int, float const *, float*)
PyTorch
6.47% 584.13ms 101 5.7835ms 5.6158ms 7.1689ms void RoIAlignForward(int, float const *, float, int, int, int, int, int, int, float const *, float*)

Using the TensorRT profiler, I get the following consistent results:

(Unnamed Layer* 217) [TopK] → | size = 100 | avg = 1.36272
(Unnamed Layer* 218) [PluginV2IOExt] → | size = 100 | avg = 7.39951
(Unnamed Layer* 219) [Slice] → | size = 100 | avg = 0.00574112
(Unnamed Layer* 22) [Convolution] + (Unnamed Layer* 24) [ElementWise] + (Unnamed Layer* 25) [Activation] → | size = 100 | avg = 0.251393
(Unnamed Layer* 220) [Slice] → | size = 100 | avg = 0.00396224
(Unnamed Layer* 222) [Identity] → | size = 100 | avg = 0.00621984
(Unnamed Layer* 223) [Constant] → | size = 100 | avg = 0.00169408
(Unnamed Layer* 224) [Identity] → | size = 100 | avg = 0.00405248
(Unnamed Layer* 227) [Shuffle] + (Unnamed Layer* 228) [Shuffle] → | size = 100 | avg = 0.0091824
(Unnamed Layer* 229) [Shuffle] → | size = 100 | avg = 0.0483347
(Unnamed Layer* 230) [Shuffle] + (Unnamed Layer* 231) [Shuffle] → | size = 100 | avg = 0.0103882
(Unnamed Layer* 232) [Shuffle] → | size = 100 | avg = 0.0459472
(Unnamed Layer* 233) [PluginV2IOExt] → | size = 100 | avg = 528.319
(Unnamed Layer* 234) [Shuffle] → | size = 100 | avg = 33.8268

The last PluginV2IOExt is the RoiAlign layer.
I don't understand why it is slower in TensorRT; I checked that the same parameters are used.
Also, the Shuffle layer right after the RoiAlign is suspiciously slow. I suspect this is some memory access issue. How can I fix it?

Below is the code for the CUDA kernel in RoiAlign; it is exactly the same in PyTorch and TensorRT.

#define CUDA_1D_KERNEL_LOOP(i, n)                                      \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;           \
       i += blockDim.x * gridDim.x)

template <typename T>
__device__ T bilinear_interpolate(const T* bottom_data,
                                  const int height, const int width,
                                  T y, T x,
                                  const int index /* index for debug only*/) {

// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
//empty
return 0;
}

if (y <= 0) y = 0;
if (x <= 0) x = 0;

int y_low = (int) y;
int x_low = (int) x;
int y_high;
int x_high;

if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T) y_low;
} else {
y_high = y_low + 1;
}

if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T) x_low;
} else {
x_high = x_low + 1;
}

T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// do bilinear interpolation
T v1 = bottom_data[y_low * width + x_low];
T v2 = bottom_data[y_low * width + x_high];
T v3 = bottom_data[y_high * width + x_low];
T v4 = bottom_data[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

return val;
}

template <typename T>
__global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
                                const T spatial_scale, const int channels,
                                const int height, const int width,
                                const int pooled_height, const int pooled_width,
                                const int sampling_ratio,
                                const T* bottom_rois, T* top_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;

const T* offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];

// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_bottom_rois[1] * spatial_scale;
T roi_start_h = offset_bottom_rois[2] * spatial_scale;
T roi_end_w = offset_bottom_rois[3] * spatial_scale;
T roi_end_h = offset_bottom_rois[4] * spatial_scale;
// T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
// T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
// T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
// T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);

// Force malformed ROIs to be 1x1
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;

// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1
{
  const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
  for (int ix = 0; ix < roi_bin_grid_w; ix ++)
  {
    const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);

    T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index);
    output_val += val;
  }
}
output_val /= count;

top_data[index] = output_val;

}
}

int RoiAlign::enqueue(int batchSize, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream)
{
const int output_size = K * C * output_height * output_width;

dim3 grid(std::min(static_cast<long>(ceil(output_size/512L)), 4096L));
dim3 block(512);

auto feat = static_cast<const float*>(inputs[0]); 
auto rois = static_cast<const float*>(inputs[1]); 
auto output = static_cast<float*>(outputs[0]); 

auto startTime = std::chrono::high_resolution_clock::now();
RoIAlignForward<float><<<grid, block, 0, stream>>>(
    output_size,                                  // nthreads
    feat,                                         // bottom_data
    spatial_scale,                                // spatial_scale
    C,                                            // channels
    H,                                            // height
    W,                                            // width
    output_height,                                // pooled_height
    output_width,                                 // pooled_width
    sampling_ratio,                               // sampling_ratio
    rois,                                         // bottom_rois
    output);                                      // top_data

cudaStreamSynchronize(stream);
auto endTime = std::chrono::high_resolution_clock::now();
float totalTime = std::chrono::duration<float, std::milli>(endTime - startTime).count();
std::cout << "\ntotalTime = " << totalTime << "\n";

return 0;

}
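A side note on the timing code in enqueue above: the cudaStreamSynchronize call is only there so the chrono measurement is meaningful, and since enqueue is expected to run asynchronously on the given stream, that synchronization should be removed outside of debugging. If kernel timing is still wanted while debugging, a small CUDA-event helper along these lines (a sketch, not part of the original plugin; the name timeOnStream is made up) measures GPU time on the stream directly:

#include <cuda_runtime.h>

// Hypothetical debugging helper: time work submitted to `stream` with CUDA
// events, so the result reflects GPU execution time rather than host-side
// launch latency, and the host waits only on the recorded events.
template <typename LaunchFn>
float timeOnStream(cudaStream_t stream, LaunchFn launch)
{
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start, stream);
    launch();                        // e.g. the RoIAlignForward<<<...>>> launch
    cudaEventRecord(stop, stream);

    cudaEventSynchronize(stop);      // wait only for the recorded work to finish
    float ms = 0.f;
    cudaEventElapsedTime(&ms, start, stop);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return ms;
}

It would be used by wrapping the existing RoIAlignForward launch in a lambda that captures grid, block, stream, and the kernel arguments.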

Have you tried calibrating your model in INT8 mode?

Hi, I have not tried that, but I don't think it would address the essence of the problem, i.e., why the execution time is so different.

I have run nvprof on the custom kernel with --metrics all on both PyTorch and TRT. There are a few differences; one that I find telling is:

PyTorch
3 sysmem_read_transactions System Memory Read Transactions 0 0 0
3 sysmem_write_transactions System Memory Write Transactions 5 5 5

TRT
3 sysmem_read_transactions System Memory Read Transactions 7840000 7840000 7840000
3 sysmem_write_transactions System Memory Write Transactions 5 5 5

In TensorRT the kernel needs access to system (host?) memory. Why is this, and how can I stop it?

The problem was that the input to the RoiAlign layer was also an output from the net.

You can generate a tiny model that only includes RoiAlign and profile the corresponding TRT engine.
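For what it's worth, a rough sketch of such a throwaway, plugin-only network with the TensorRT 7 C++ API could look like the following. The input shapes, tensor names, and implicit-batch setup are placeholders rather than the poster's model, and `logger` is any nvinfer1::ILogger implementation (for example the one shipped with the TensorRT samples):

#include <NvInfer.h>

// Sketch: build a throwaway engine containing only the RoiAlign plugin so it
// can be profiled in isolation. Shapes and names below are placeholders.
nvinfer1::ICudaEngine* buildRoiAlignOnlyEngine(nvinfer1::ILogger& logger,
                                               nvinfer1::IPluginV2& roiAlignPlugin)
{
    nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U);  // implicit batch

    // Placeholder input shapes: a 256x64x64 feature map and 100 rois of 5 values each.
    nvinfer1::ITensor* feat = network->addInput("feat", nvinfer1::DataType::kFLOAT,
                                                nvinfer1::Dims3{256, 64, 64});
    nvinfer1::ITensor* rois = network->addInput("rois", nvinfer1::DataType::kFLOAT,
                                                nvinfer1::Dims2{100, 5});

    nvinfer1::ITensor* inputs[] = { feat, rois };
    nvinfer1::IPluginV2Layer* roiAlign = network->addPluginV2(inputs, 2, roiAlignPlugin);
    network->markOutput(*roiAlign->getOutput(0));

    nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
    config->setMaxWorkspaceSize(1ULL << 28);
    builder->setMaxBatchSize(1);
    return builder->buildEngineWithConfig(*network, *config);
}

The resulting engine can then be profiled on its own with nvprof or the same TensorRT profiler output shown above.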

Hi, thanks for the feedback. The RoiAlign speed problem was solved, as I mentioned.
The reason for the poor performance was that the rois input to the layer was also an output from the net that was being stored to system memory. The RoiAlign layer was then reading all data from that memory.
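In API terms, the fix amounts to making sure the tensor feeding the plugin is not also marked as a network output. A minimal sketch with the TensorRT 7 C++ API, using hypothetical names (featTensor, roisTensor, roiAlignPlugin) rather than the poster's actual build code:

#include <NvInfer.h>

// Sketch only: attach the RoiAlign plugin so that its rois input stays an
// internal, device-side tensor instead of doubling as a network output.
nvinfer1::ITensor* addRoiAlign(nvinfer1::INetworkDefinition& network,
                               nvinfer1::ITensor& featTensor,
                               nvinfer1::ITensor& roisTensor,
                               nvinfer1::IPluginV2& roiAlignPlugin)
{
    if (roisTensor.isNetworkOutput())
        network.unmarkOutput(roisTensor);   // rois no longer bound as an engine output

    nvinfer1::ITensor* inputs[] = { &featTensor, &roisTensor };
    nvinfer1::IPluginV2Layer* layer = network.addPluginV2(inputs, 2, roiAlignPlugin);
    return layer->getOutput(0);
}

If the rois still need to be exported from the engine, one option is to route a copy through an Identity layer and mark that copy as the output instead of the tensor that feeds the plugin.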

Another problem has arisen. I am not starting a new thread for it because it relates to the same design, even though it does not concern the RoiAlign layer specifically.

I am very grateful for any feedback, hopefully, someone has experience with TensorRT for FasterRCNN-like models and can help resolve the issue.

All layers processing the output from the RoiAlign module are very slow. I attach output from the TRT profiler:

(Unnamed Layer* 224) [PluginV2IOExt] → | size = 5 | avg = 6.3378
(Unnamed Layer* 226) [Shuffle] → | size = 5 | avg = 0.0104896
(Unnamed Layer* 227) [Convolution] + (Unnamed Layer* 229) [Activation] → | size = 5 | avg = 5.02792
(Unnamed Layer* 230) [Convolution] + (Unnamed Layer* 232) [Activation] → | size = 5 | avg = 6.94584
(Unnamed Layer* 233) [Convolution] + (Unnamed Layer* 237) [ElementWise] + (Unnamed Layer* 238) [Activation] → | size = 5 | avg = 3.51948
(Unnamed Layer* 235) [Convolution] → | size = 5 | avg = 8.87447
(Unnamed Layer* 239) [Convolution] + (Unnamed Layer* 241) [Activation] → | size = 5 | avg = 3.13203
(Unnamed Layer* 242) [Convolution] + (Unnamed Layer* 244) [Activation] → | size = 5 | avg = 6.88691
(Unnamed Layer* 245) [Convolution] + (Unnamed Layer* 247) [ElementWise] + (Unnamed Layer* 248) [Activation] → | size = 5 | avg = 3.45716
(Unnamed Layer* 249) [Convolution] + (Unnamed Layer* 251) [Activation] → | size = 5 | avg = 3.08855
(Unnamed Layer* 252) [Convolution] + (Unnamed Layer* 254) [Activation] → | size = 5 | avg = 6.8191
(Unnamed Layer* 255) [Convolution] + (Unnamed Layer* 257) [ElementWise] + (Unnamed Layer* 258) [Activation] → | size = 5 | avg = 3.49404
(Unnamed Layer* 259) [Reduce] → | size = 5 | avg = 14.1211
(Unnamed Layer* 260) [Reduce] → | size = 5 | avg = 2.92504

These layers are very slow; by comparison, other layers in the network are much faster, for example:
(Unnamed Layer* 26) [Convolution] + (Unnamed Layer* 28) [Activation] → | size = 5 | avg = 0.14816

It is very suspicious that a layer like Reduce is so slow. I don't think the problem is Reduce-specific, but I looked closer into it. Two different kernels are used:

14.31% 70.591ms 5 14.118ms 14.062ms 14.169ms void cuReduceLayer::tailReduceFast<unsigned int=32, nvinfer1::ReduceOp, float, float>(float*, cuReduceLayer::tailReduceFast<unsigned int=32, nvinfer1::ReduceOp, float, float> const *, cuReduceLayer::LaunchParams)

2.96% 14.592ms 5 2.9184ms 2.7397ms 3.2506ms void cuReduceLayer::tailReduce3DH<unsigned int=32, nvinfer1::ReduceOp, float, float>(float*, cuReduceLayer::tailReduce3DH<unsigned int=32, nvinfer1::ReduceOp, float, float> const *, cuReduceLayer::LaunchParams)

I analyzed tailReduceFast and nothing specific has caught my eye yet, except maybe the poor memory load efficiency:

      4                            gld_efficiency                                         Global Memory Load Efficiency      25.00%      25.00%      25.00%
      4                            gst_efficiency                                        Global Memory Store Efficiency     100.00%     100.00%     100.00%

Hi,

Could you please provide details on the platforms you are using so we can better help:
o Linux distro and version
o GPU type
o Nvidia driver version
o CUDA version
o CUDNN version
o Python version [if using python]
o TensorFlow and PyTorch version
o TensorRT version
If possible, please share the script & model file to reproduce the issue.

Thanks

Here is the clarification for the problems I had:

  1. The layers after RoiAlign are slower due to higher computational demands. This is consistent with the PyTorch running time and can be mitigated by lowering the number of ROIs used.
  2. The very inefficient Reduce layers appear due to the following PyTorch code:
    features.mean(3).mean(2). If I change that to features.mean((2,3)), the TRT running time is reduced significantly. The PyTorch running time is not affected (see the sketch below).
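Consistent with the two Reduce layers in the profile above, features.mean(3).mean(2) ends up as two chained reductions, while features.mean((2,3)) can map to a single average reduction over both spatial axes. A minimal sketch of the single-reduce form with the TensorRT C++ API (the helper name and the implicit-batch CHW axis numbering are assumptions):

#include <NvInfer.h>
#include <cstdint>

// Hypothetical helper mirroring features.mean((2, 3)): one Reduce layer that
// averages over H and W of a CHW tensor (implicit-batch axis numbering assumed).
nvinfer1::ITensor* addSpatialMean(nvinfer1::INetworkDefinition& network,
                                  nvinfer1::ITensor& features)
{
    const uint32_t hwAxes = (1U << 1) | (1U << 2);   // bitmask selecting the H and W axes
    nvinfer1::IReduceLayer* pooled = network.addReduce(
        features, nvinfer1::ReduceOperation::kAVG, hwAxes, /*keepDimensions=*/false);
    return pooled->getOutput(0);
}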

Hi @eascheiber,

Can you please provide the TRT plugin registration code? It would be a great help.

Thanks

Hi, please check the reference link below for a custom plugin implementation.
https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/sampleOnnxMnistCoordConvAC
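Since the registration code was asked for above but not posted, here is a skeleton of how an IPluginV2 creator is typically registered with TensorRT 7. This is not the poster's code: the RoiAlignPlugin.h header name and the two RoiAlign constructors (from a PluginFieldCollection and from serialized data) are assumptions, and the plugin-field parsing is omitted.

#include <NvInfer.h>
#include <string>

#include "RoiAlignPlugin.h"   // hypothetical header declaring the poster's RoiAlign class

using namespace nvinfer1;

class RoiAlignPluginCreator : public IPluginCreator
{
public:
    const char* getPluginName() const override { return "RoiAlign"; }
    const char* getPluginVersion() const override { return "1"; }

    // A real creator would populate this with PluginField entries
    // (spatial_scale, output_height, output_width, sampling_ratio, ...).
    const PluginFieldCollection* getFieldNames() override { return &mFC; }

    IPluginV2* createPlugin(const char* /*name*/, const PluginFieldCollection* fc) override
    {
        return new RoiAlign(fc);                       // assumed constructor
    }

    IPluginV2* deserializePlugin(const char* /*name*/, const void* serialData,
                                 size_t serialLength) override
    {
        return new RoiAlign(serialData, serialLength); // assumed constructor
    }

    void setPluginNamespace(const char* libNamespace) override { mNamespace = libNamespace; }
    const char* getPluginNamespace() const override { return mNamespace.c_str(); }

private:
    PluginFieldCollection mFC{};
    std::string mNamespace;
};

// Makes the creator discoverable via
// getPluginRegistry()->getPluginCreator("RoiAlign", "1").
REGISTER_TENSORRT_PLUGIN(RoiAlignPluginCreator);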

Thanks!