How to optimize the custom bilinear sampling alternative to grid_sample for TensorRT inference?


I am trying to convert a model that uses torch.nn.functional.grid_sample from PyTorch (1.6) to TensorRT (7) through ONNX (opset 11). Opset 11 does not support grid_sample conversion.
The custom alternative I found (ONNX and grid_sample layer · Issue #27212 · pytorch/pytorch · GitHub) is extremely slow in PyTorch and has problems converting its main loop to TRT.

My own implementation of bilinear sampling (not just grid_sample, but the whole original sampling, based on grid_sample) runs much faster in PyTorch and converts to TRT successfully. However, my custom bilinear sampling is slower in TRT than in PyTorch (5.6 ms vs 2.0 ms). It turns out that PyTorch's image[:, ind, y0, x0] indexing produces a Gather layer with a running time of about 0.97 ms, and there are four such layers in the TRT version of the bilinear sampling.
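To illustrate why this indexing becomes a Gather, the sketch below writes the same lookup two ways: PyTorch advanced indexing on the N, H and W dimensions, and a single gather over a flattened N*H*W axis with the linear index p = (n*H + y)*W + x. The helper names advanced_index_gather and flat_gather are hypothetical. The Mul/Add/Gather sequence in the trtexec profile appears consistent with this linear-index form, but that is an inference, not something checked against the exported ONNX graph.

```python
import torch

def advanced_index_gather(image_cnhw, ind, y, x):
    # The lookup used in bilinear_sample_noloop: advanced indexing on N, H, W.
    return image_cnhw[:, ind, y, x]

def flat_gather(image_cnhw, ind, y, x):
    # The same lookup as a gather over one flattened N*H*W axis, using the
    # linear index p = (n * H + y) * W + x.
    C, N, H, W = image_cnhw.shape
    p = (ind * H + y) * W + x                # [N, grid_H, grid_W]
    flat = image_cnhw.reshape(C, N * H * W)  # [C, N*H*W]
    return flat[:, p]                        # [C, N, grid_H, grid_W]

torch.manual_seed(0)
C, N, H, W = 2, 3, 4, 5
image = torch.randn(C, N, H, W)
ind = torch.arange(N).view(N, 1, 1).expand(N, 6, 7)
y = torch.randint(0, H, (N, 6, 7))
x = torch.randint(0, W, (N, 6, 7))
print(torch.equal(advanced_index_gather(image, ind, y, x), flat_gather(image, ind, y, x)))  # True
```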

So the questions are:

  • How should I optimize my PyTorch code to get an efficient TRT model?
  • What can I do to make the Gather layer faster?
  • Would implementing this function as a custom TRT plugin help make it faster?

Additionally, I tried viewing the image as a linear array over all dimensions except C and creating linear indices to address the elements in the form image[:, p0]. In this case Gather becomes even slower (about 1.07 ms). Then I assumed C=1 (as it always is in the original model) and addressed the tensor elements as image[p0]. This time Gather takes about 0.92 ms (still too slow).
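Building on the C=1 linear-index variant, one possible direction is to fuse the four corner lookups into a single Gather: compute the linear index of the top-left corner once, add the constant offsets 0, W, 1 and W + 1 for the other three corners, and index the flattened image with one stacked [4, N, grid_H, grid_W] index tensor. This is a sketch assuming C = 1; the name bilinear_sample_single_gather is hypothetical, and whether one larger Gather is faster than four smaller ones in TRT 7 has not been measured.

```python
import torch

def bilinear_sample_single_gather(image, grid):
    """Hypothetical C == 1 variant of bilinear_sample_noloop with one gather.

    image: [N, 1, H, W]; grid: [N, grid_H, grid_W, 2] pixel coordinates (x, y).
    Returns (samples [N, 1, grid_H, grid_W], mask [N, grid_H, grid_W, 1]).
    """
    N, C, H, W = image.shape
    assert C == 1, "this sketch assumes a single channel"
    gx, gy = grid[..., 0], grid[..., 1]  # each [N, grid_H, grid_W]
    # Same validity rule as the post: all four neighbours must be inside.
    mask = ((gx >= 0) & (gy >= 0) & (gx < W - 1) & (gy < H - 1)).to(image.dtype)
    x0 = torch.floor(gx)
    y0 = torch.floor(gy)
    fx = gx - x0  # fractional offsets inside the pixel cell
    fy = gy - y0
    # Zero invalid coordinates so every linear index stays in range.
    x0i = (x0 * mask).long()
    y0i = (y0 * mask).long()
    n = torch.arange(N, device=image.device).view(N, 1, 1)
    base = (n * H + y0i) * W + x0i  # linear index of the (y0, x0) corner
    # Offsets to corners (y0, x0), (y1, x0), (y0, x1), (y1, x1).
    offsets = torch.tensor([0, W, 1, W + 1], device=image.device).view(4, 1, 1, 1)
    p = base.unsqueeze(0) + offsets  # [4, N, grid_H, grid_W]
    values = image.reshape(-1)[p]    # a single gather for all four corners
    # Weights in the same order as wa, wb, wc, wd in bilinear_sample_noloop.
    weights = torch.stack([(1 - fx) * (1 - fy), (1 - fx) * fy,
                           fx * (1 - fy), fx * fy])
    out = (values * weights).sum(dim=0) * mask  # [N, grid_H, grid_W]
    return out.unsqueeze(1), mask.unsqueeze(-1)
```

For C = 1 its outputs should match bilinear_sample_noloop up to floating-point rounding, so it could be swapped into BilinearSamplingBlock for an ONNX export and trtexec comparison.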


TensorRT Version: 7
GPU Type: NVIDIA GeForce GTX 1050 Ti
Nvidia Driver Version: 460.73.01
CUDA Version: 10.1
CUDNN Version:
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 3.7
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.6
Baremetal or Container (if container which image + tag):

Relevant Files

Python script for timing bilinear sampling in PyTorch and exporting the ONNX model:

```python
import time
import torch

def time_stamp(sync=True):
    # Wait for all queued CUDA kernels so the timestamp covers the GPU work.
    if sync:
        torch.cuda.synchronize()
    return time.time()

def bilinear_sample_noloop(image, grid):
    """Bilinear sampling with no loops.
    :param image: sampling source of shape [N, C, H, W]
    :param grid: sampling coordinates (x, y) in pixel units, of shape [N, grid_H, grid_W, 2]
    :return: sampling result of shape [N, C, grid_H, grid_W] and validity mask of shape [N, grid_H, grid_W, 1]
    """
    Nt, C, H, W = image.shape
    grid_H = grid.shape[1]
    grid_W = grid.shape[2]
    xgrid, ygrid = grid.split([1, 1], dim=-1)
    # A sample is valid only if all four neighbouring pixels lie inside the image.
    mask = ((xgrid >= 0) & (ygrid >= 0) & (xgrid < W - 1) & (ygrid < H - 1)).float()
    x0 = torch.floor(xgrid)
    x1 = x0 + 1
    y0 = torch.floor(ygrid)
    y1 = y0 + 1
    # Bilinear weights of the four neighbours, permuted to [1, N, grid_H, grid_W].
    wa = ((x1 - xgrid) * (y1 - ygrid)).permute(3, 0, 1, 2)
    wb = ((x1 - xgrid) * (ygrid - y0)).permute(3, 0, 1, 2)
    wc = ((xgrid - x0) * (y1 - ygrid)).permute(3, 0, 1, 2)
    wd = ((xgrid - x0) * (ygrid - y0)).permute(3, 0, 1, 2)
    # Zero the coordinates of invalid samples so every index stays in range.
    x0 = (x0 * mask).view(Nt, grid_H, grid_W).long()
    y0 = (y0 * mask).view(Nt, grid_H, grid_W).long()
    x1 = (x1 * mask).view(Nt, grid_H, grid_W).long()
    y1 = (y1 * mask).view(Nt, grid_H, grid_W).long()
    # Batch index of every sample, shape [N, grid_H, grid_W].
    ind = torch.arange(Nt, device=image.device) #torch.linspace(0, Nt - 1, Nt, device=image.device)
    ind = ind.view(Nt, 1).expand(-1, grid_H).view(Nt, grid_H, 1).expand(-1, -1, grid_W).long()
    image = image.permute(1, 0, 2, 3)  # [C, N, H, W]
    # These four lookups become the four Gather layers in the TRT profile.
    output_tensor = (image[:, ind, y0, x0] * wa + image[:, ind, y1, x0] * wb + image[:, ind, y0, x1] * wc +
                     image[:, ind, y1, x1] * wd).permute(1, 0, 2, 3)
    output_tensor *= mask.permute(0, 3, 1, 2).expand(-1, C, -1, -1)
    image = image.permute(1, 0, 2, 3)
    return output_tensor, mask

class BilinearSamplingBlock(torch.nn.Module):
    def __init__(self):
        super(BilinearSamplingBlock, self).__init__()

    def forward(self, inputs):
        image, grid = inputs
        samples, mask = bilinear_sample_noloop(image, grid)
        return [samples, mask]

if __name__ == '__main__':
    model = torch.nn.DataParallel(BilinearSamplingBlock())
    dummy_image = torch.ones(3600, 1, 45, 80, device='cuda').float()
    dummy_grid = torch.ones(3600, 9, 9, 2, device='cuda').float()
    dummy_inputs = [dummy_image, dummy_grid]
    with torch.no_grad():
        smax = 100
        stime = 0
        for i in range(smax):
            t1 = time_stamp()
            ex_out = model.module(dummy_inputs)
            t2 = time_stamp()
            stime += t2 - t1
        print("Mean model time = ", 1000 * stime / smax, " ms")
        torch.onnx.export(model.module, dummy_inputs, "sampling_block.onnx", opset_version=11, example_outputs=ex_out, verbose=True)
```
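As a correctness check for the custom sampler, its pixel-unit coordinates can be converted to the [-1, 1] range that torch.nn.functional.grid_sample expects. This sketch assumes align_corners=True, where -1 and 1 map to the centres of the first and last pixels, so x_norm = 2x/(W - 1) - 1. The helper names pixel_to_normalized and reference_sample are hypothetical. The two samplers should agree on points strictly inside [0, W-1) × [0, H-1); near the border they differ, because bilinear_sample_noloop zeroes the whole sample while padding_mode="zeros" only zeroes the missing neighbours.

```python
import torch
import torch.nn.functional as F

def pixel_to_normalized(grid, H, W):
    # With align_corners=True, -1 and 1 are the centres of the first and last
    # pixels, so x_norm = 2 * x / (W - 1) - 1 (and likewise for y).
    x = grid[..., 0] * 2.0 / (W - 1) - 1.0
    y = grid[..., 1] * 2.0 / (H - 1) - 1.0
    return torch.stack([x, y], dim=-1)

def reference_sample(image, grid):
    # Reference bilinear sampling for pixel-unit coordinates via grid_sample.
    _, _, H, W = image.shape
    return F.grid_sample(image, pixel_to_normalized(grid, H, W), mode="bilinear",
                         padding_mode="zeros", align_corners=True)

# On a linear ramp, bilinear interpolation is exact: value(x, y) = y * W + x.
H, W = 4, 5
ramp = torch.arange(H * W, dtype=torch.float32).view(1, 1, H, W)
pts = torch.tensor([[[[1.5, 2.25], [3.0, 0.5]]]])  # [N=1, grid_H=1, grid_W=2, 2]
print(reference_sample(ramp, pts))  # values close to 12.75 and 5.5
```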

Steps To Reproduce

  • Run the bilinear sampling Python script to get the PyTorch time and the ONNX model;
  • Run ‘trtexec --onnx=sampling_block.onnx --workspace=128 --dumpProfile’;
  • Compare the PyTorch time with the TensorRT time and look for the “Gather_NNN” layers in the TRT profile.
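The timing loop in the script has no warm-up, so its first iterations can include one-off costs such as kernel and memory-allocator warm-up. A hedged alternative, using a hypothetical helper mean_time_ms, runs a few warm-up calls and then times the loop with torch.cuda.Event when CUDA is available, falling back to time.perf_counter on CPU. It times the same PyTorch call as the script, not a TensorRT engine.

```python
import time
import torch

def mean_time_ms(fn, *args, iters=100, warmup=10):
    """Hypothetical helper: mean time of fn(*args) in milliseconds."""
    with torch.no_grad():
        for _ in range(warmup):  # discard first calls (kernel/allocator warm-up)
            fn(*args)
        if torch.cuda.is_available():
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()  # make sure warm-up work has finished
            start.record()
            for _ in range(iters):
                fn(*args)
            end.record()
            torch.cuda.synchronize()  # wait until the end event has been reached
            return start.elapsed_time(end) / iters  # elapsed_time() returns ms
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        return 1000.0 * (time.perf_counter() - t0) / iters
```

For example, mean_time_ms(model.module, dummy_inputs) times the same call as the manual loop in the __main__ block.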

Part of the TRT layer profile from trtexec:
```
Layer        Time (ms)   Avg. Time (ms)   Time %
Mul_146           5.82             0.03      0.5
Add_147           8.50             0.04      0.7
Gather_148      214.39             0.97     17.3
Gather_174      214.25             0.97     17.3
Gather_201      213.88             0.97     17.3
Gather_228      214.48             0.97     17.3
Add_237))        25.01             0.11      2.0
Mul_251           7.84             0.04      0.6
Total          1238.40             5.60    100.0
```
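Adding up the four Gather rows shows how much of the engine time they account for:

```latex
4 \times 0.97\,\text{ms} = 3.88\,\text{ms}, \qquad \frac{3.88\,\text{ms}}{5.60\,\text{ms}} \approx 69\%
```

This matches 4 × 17.3% = 69.2% in the Time % column, leaving roughly 1.7 ms for all other layers.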

Please refer to the installation steps at the link below in case you have missed anything.
However, the suggested approach is to use the TRT NGC containers to avoid system-dependency issues.

To run the Python sample, make sure the TRT Python packages are installed when using the NGC container.

Thank you for the response.
For running TRT inference I used a Docker container.
I created the TRT engine from the ONNX model with trtexec and then profiled the inference time in three ways:

  • With trtexec;
  • With custom C++ inference code;
  • With custom Python inference code.

In all cases I got similar results (Python was slightly slower, around 6.3 ms). The detailed per-layer profile was obtained with trtexec.

Hi @SergiySnake,

Sorry for the delayed response. IMO, a plugin would be the best approach. It looks like someone has implemented one here:

Thank you.


Thank you! The plugin together with a dummy ONNX operation seems to have solved my problem, at least partially. However, I noticed one peculiarity of this code. ONNX GraphSurgeon is used to replace the dummy ONNX operation with a reference to the corresponding plugin, and during this replacement the plugin's buffer size is calculated from the shapes of the input tensors. These shapes are known during the PyTorch-to-ONNX conversion but are not saved in the ONNX file: after export, only the shapes of the model inputs and outputs are known, not those of intermediate layers. So if the GridSample layer is the first one (or the only one, as in the test example), everything works. In more complex models, however, an error occurs because the shape of a None object cannot be determined.

As a workaround I used predetermined shapes for the buffer size calculation. This is not a good solution, though: GridSample is applied to tensors of different shapes, so I have to use the largest one.