Hi,
I have an object detection pipeline in PyTorch that I converted to TensorRT 7. The pipeline uses a custom RoIAlign op, for which I wrote a custom plugin. The exact same CUDA kernel runs in both PyTorch and TensorRT; however, the TensorRT execution is VERY slow.
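For context, the plugin derives from nvinfer1::IPluginV2IOExt. Here is a trimmed sketch of the declaration, showing just the members used by the enqueue() listed at the end (the full class implements all the remaining interface methods):

#include <NvInfer.h>

class RoiAlign : public nvinfer1::IPluginV2IOExt
{
public:
    int enqueue(int batchSize, const void* const* inputs, void** outputs,
                void* workspace, cudaStream_t stream) override;
    // ... all other IPluginV2/IPluginV2Ext/IPluginV2IOExt overrides omitted ...

private:
    int K;                            // number of ROIs
    int C, H, W;                      // input feature map shape
    int output_height, output_width;  // pooled output size
    float spatial_scale;
    int sampling_ratio;
};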
Here are nvprof results:
TensorRT:
89.91%  52.8266s  100  528.27ms  497.70ms  563.38ms  void RoIAlignForward<float>(int, float const *, float, int, int, int, int, int, int, float const *, float*)
PyTorch:
6.47%  584.13ms  101  5.7835ms  5.6158ms  7.1689ms  void RoIAlignForward<float>(int, float const *, float, int, int, int, int, int, int, float const *, float*)
Using the TensorRT profiler, I get the following consistent per-layer results (avg in ms over the 100 runs; the profiler hook is sketched after the table):
(Unnamed Layer* 217) [TopK] → | size = 100 | avg = 1.36272
(Unnamed Layer* 218) [PluginV2IOExt] → | size = 100 | avg = 7.39951
(Unnamed Layer* 219) [Slice] → | size = 100 | avg = 0.00574112
(Unnamed Layer* 22) [Convolution] + (Unnamed Layer* 24) [ElementWise] + (Unnamed Layer* 25) [Activation] → | size = 100 | avg = 0.251393
(Unnamed Layer* 220) [Slice] → | size = 100 | avg = 0.00396224
(Unnamed Layer* 222) [Identity] → | size = 100 | avg = 0.00621984
(Unnamed Layer* 223) [Constant] → | size = 100 | avg = 0.00169408
(Unnamed Layer* 224) [Identity] → | size = 100 | avg = 0.00405248
(Unnamed Layer* 227) [Shuffle] + (Unnamed Layer* 228) [Shuffle] → | size = 100 | avg = 0.0091824
(Unnamed Layer* 229) [Shuffle] → | size = 100 | avg = 0.0483347
(Unnamed Layer* 230) [Shuffle] + (Unnamed Layer* 231) [Shuffle] → | size = 100 | avg = 0.0103882
(Unnamed Layer* 232) [Shuffle] → | size = 100 | avg = 0.0459472
(Unnamed Layer* 233) [PluginV2IOExt] → | size = 100 | avg = 528.319
(Unnamed Layer* 234) [Shuffle] → | size = 100 | avg = 33.8268
The last PluginV2IOExt, (Unnamed Layer* 233), is the RoIAlign layer.
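The per-layer numbers above come from a simple IProfiler attached to the execution context. A minimal sketch of the hook, assuming the engine, context, and bindings are set up elsewhere:

#include <NvInfer.h>
#include <map>
#include <string>

struct LayerProfiler : public nvinfer1::IProfiler
{
    void reportLayerTime(const char* layerName, float ms) override
    {
        mTotalMs[layerName] += ms;  // accumulated over the 100 runs
        ++mCount[layerName];
    }
    std::map<std::string, float> mTotalMs;
    std::map<std::string, int> mCount;
};

// Attach before the timing loop; TensorRT then reports per-layer times
// after every (synchronous) execute() call.
void profileRuns(nvinfer1::IExecutionContext* context, void** bindings)
{
    LayerProfiler profiler;
    context->setProfiler(&profiler);
    for (int i = 0; i < 100; ++i)
        context->execute(1, bindings);
}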
I don’t understand why the same kernel is roughly 90x slower in TensorRT (528 ms vs. 5.8 ms per call); I have checked that both frameworks call it with the same parameters. Also, the Shuffle layer right after the RoIAlign ((Unnamed Layer* 234), ~34 ms) is suspiciously slow. I suspect some kind of memory access issue. How can I fix it?
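To rule out the plugin receiving bad inputs, I temporarily dump the first few ROIs from inside enqueue() and compare them against PyTorch. A debug-only sketch (N_DEBUG is arbitrary; rois and stream are the same variables as in the enqueue() listed at the end):

#include <cstdio>
#include <cuda_runtime.h>

// Paste inside enqueue(), after rois has been cast from inputs[1].
const int N_DEBUG = 5;
float host_rois[N_DEBUG * 5];
cudaMemcpyAsync(host_rois, rois, sizeof(host_rois),
                cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);  // debug only: make the copy visible on the host
for (int i = 0; i < N_DEBUG; ++i)
    printf("roi %d: batch=%d  x1=%.1f y1=%.1f x2=%.1f y2=%.1f\n", i,
           (int)host_rois[i * 5 + 0], host_rois[i * 5 + 1],
           host_rois[i * 5 + 2], host_rois[i * 5 + 3], host_rois[i * 5 + 4]);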
Below is the code for the RoIAlign CUDA kernel; it is exactly the same in PyTorch and TensorRT.
#include <algorithm>
#include <chrono>
#include <iostream>
#include <cuda_runtime.h>

#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)
template <typename T>
__device__ T bilinear_interpolate(const T* bottom_data,
                                  const int height, const int width,
                                  T y, T x,
                                  const int index /* index for debug only */) {
  // deal with cases that inverse elements are out of feature map boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width) {
    // empty
    return 0;
  }

  if (y <= 0) y = 0;
  if (x <= 0) x = 0;

  int y_low = (int)y;
  int x_low = (int)x;
  int y_high;
  int x_high;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;

  // do bilinear interpolation
  T v1 = bottom_data[y_low * width + x_low];
  T v2 = bottom_data[y_low * width + x_high];
  T v3 = bottom_data[y_high * width + x_low];
  T v4 = bottom_data[y_high * width + x_high];
  T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
template <typename T>
__global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
                                const T spatial_scale, const int channels,
                                const int height, const int width,
                                const int pooled_height, const int pooled_width,
                                const int sampling_ratio,
                                const T* bottom_rois, T* top_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_bottom_rois = bottom_rois + n * 5;
    int roi_batch_ind = offset_bottom_rois[0];

    // Do not use rounding; this implementation detail is critical
    T roi_start_w = offset_bottom_rois[1] * spatial_scale;
    T roi_start_h = offset_bottom_rois[2] * spatial_scale;
    T roi_end_w = offset_bottom_rois[3] * spatial_scale;
    T roi_end_h = offset_bottom_rois[4] * spatial_scale;
    // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
    // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
    // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
    // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);

    // Force malformed ROIs to be 1x1
    T roi_width = max(roi_end_w - roi_start_w, (T)1.);
    T roi_height = max(roi_end_h - roi_start_h, (T)1.);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // We do average (integral) pooling inside a bin
    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

    T output_val = 0.;
    for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
    {
      const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++)
      {
        const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
        T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index);
        output_val += val;
      }
    }
    output_val /= count;
    top_data[index] = output_val;
  }
}
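And here is the plugin's enqueue() that launches the kernel: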
int RoiAlign::enqueue(int batchSize, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream)
{
    const int output_size = K * C * output_height * output_width;
    dim3 grid(std::min((output_size + 511) / 512, 4096));  // ceil-div, capped at 4096 blocks
    dim3 block(512);

    auto feat = static_cast<const float*>(inputs[0]);
    auto rois = static_cast<const float*>(inputs[1]);
    auto output = static_cast<float*>(outputs[0]);

    auto startTime = std::chrono::high_resolution_clock::now();
    RoIAlignForward<float><<<grid, block, 0, stream>>>(
        output_size,     // nthreads
        feat,            // bottom_data
        spatial_scale,   // spatial_scale
        C,               // channels
        H,               // height
        W,               // width
        output_height,   // pooled_height
        output_width,    // pooled_width
        sampling_ratio,  // sampling_ratio
        rois,            // bottom_rois
        output);         // top_data
    // The synchronization is only here so the wall-clock timing below measures
    // the kernel; it is removed in the production build.
    cudaStreamSynchronize(stream);
    auto endTime = std::chrono::high_resolution_clock::now();
    float totalTime = std::chrono::duration<float, std::milli>(endTime - startTime).count();
    std::cout << "\ntotalTime = " << totalTime << "\n";
    return 0;
}
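Finally, to take TensorRT out of the picture completely, the kernel can be benchmarked in a standalone harness with CUDA events, compiled in the same .cu file as the kernel above. A sketch, where the shapes, spatial_scale, and sampling_ratio are hard-coded placeholders standing in for my model's real values:

#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

int main()
{
    // Placeholder shapes: substitute the real values from the network.
    const int K = 100, C = 256, H = 200, W = 336;
    const int pooled_h = 7, pooled_w = 7;
    const int output_size = K * C * pooled_h * pooled_w;

    float *feat, *rois, *output;
    cudaMalloc(&feat, sizeof(float) * C * H * W);
    cudaMalloc(&rois, sizeof(float) * K * 5);
    cudaMalloc(&output, sizeof(float) * output_size);
    // Zero the ROIs so batch indices and coordinates stay in bounds.
    cudaMemset(rois, 0, sizeof(float) * K * 5);

    dim3 block(512);
    dim3 grid(std::min((output_size + 511) / 512, 4096));

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);
    // spatial_scale = 0.25f and sampling_ratio = 2 are placeholders too.
    RoIAlignForward<float><<<grid, block>>>(output_size, feat, 0.25f, C, H, W,
                                            pooled_h, pooled_w, 2, rois, output);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float ms = 0.f;
    cudaEventElapsedTime(&ms, start, stop);
    printf("standalone kernel time = %.3f ms\n", ms);
    return 0;
}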