I am trying to convert a model that uses torch.nn.functional.grid_sample from PyTorch (1.6) to TensorRT (7) through ONNX (opset 11). Opset 11 does not support grid_sample conversion.
The custom alternative I found (ONNX and grid_sample layer · Issue #27212 · pytorch/pytorch · GitHub) is extremely slow when running in PyTorch, and I ran into problems converting its main loop to TRT.
My own implementation of bilinear sampling (not just grid_sample, but the whole original sampling, based on grid_sample) is much faster in PyTorch and converts to TRT successfully. However, my custom bilinear sampling is slower in TRT than in PyTorch (5.6 ms vs. 2.0 ms). It turns out that the PyTorch indexing image[:, ind, y0, x0] produces a Gather layer that takes about 0.97 ms, and the TRT version of the bilinear sampling contains four such layers (one per corner of the 2×2 neighbourhood).
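As a plain-Python reference, the arithmetic for a single output point can be sketched as follows. `bilinear_sample_point` is a hypothetical helper written only for illustration, assuming a single-channel image stored as a list of rows and the same validity rule as the mask in the script below. Every output value needs exactly four reads, one per corner, and it is these reads that become the four Gather layers once the indexing is vectorised.

```python
import math


def bilinear_sample_point(image, x, y):
    """Bilinearly sample a single-channel image (list of rows) at pixel (x, y).

    Hypothetical reference helper: uses the same weights wa..wd as
    bilinear_sample_noloop and returns 0.0 when the 2x2 neighbourhood
    is not fully inside the image (the script's mask rule).
    """
    H, W = len(image), len(image[0])
    if not (0 <= x < W - 1 and 0 <= y < H - 1):
        return 0.0  # masked-out point
    x0, y0 = math.floor(x), math.floor(y)
    x1, y1 = x0 + 1, y0 + 1
    wa = (x1 - x) * (y1 - y)  # weight of corner (y0, x0)
    wb = (x1 - x) * (y - y0)  # weight of corner (y1, x0)
    wc = (x - x0) * (y1 - y)  # weight of corner (y0, x1)
    wd = (x - x0) * (y - y0)  # weight of corner (y1, x1)
    # Four reads per output value: these are the four Gathers in the TRT graph.
    return (image[y0][x0] * wa + image[y1][x0] * wb
            + image[y0][x1] * wc + image[y1][x1] * wd)
```

For example, sampling `[[0, 1], [2, 3]]` at `(0.5, 0.5)` weights all four corners by 0.25 and gives 1.5.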
So my questions are:
- How should I optimize my PyTorch code to get an efficient TRT model?
- What can I do to make the Gather layer faster?
- Would implementing this function as a custom TRT plugin make it faster?
I also tried viewing the image as a linear array over all dimensions except C and building linear indices to address the elements as image[:, p0]. In this case Gather becomes even slower (about 1.07 ms). I then assumed C=1 (as is always the case in the original model) and addressed the tensor elements as image[p0]. This time Gather takes about 0.92 ms, which is still too slow.
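The linear-index variant with C = 1 can be sketched in NumPy; the same expressions map onto `torch.arange`, `torch.stack` and `reshape`. `gather_corners_flat` is a hypothetical helper, assuming an image of shape [N, 1, H, W] and top-left corner indices that already satisfy the mask (so `x0 + 1 < W` and `y0 + 1 < H`). Because the other three corners sit at fixed offsets (+W, +1, +W+1) from p0, all four reads can be issued as one gather over a stacked index tensor. Whether TensorRT 7 runs one larger Gather faster than four smaller ones is an untested assumption; the sketch only shows the index arithmetic.

```python
import numpy as np


def gather_corners_flat(image, x0, y0):
    """Read the four bilinear corners with a single gather on a flattened image.

    image:  array of shape [N, 1, H, W] (C = 1, as in the original model)
    x0, y0: integer arrays of shape [N, grid_H, grid_W] (top-left corners)
    Returns shape [4, N, grid_H, grid_W], ordered (y0,x0), (y1,x0), (y0,x1), (y1,x1).
    """
    N, C, H, W = image.shape
    assert C == 1, "this sketch assumes a single channel"
    flat = image.reshape(-1)             # linear array of N*H*W elements
    n = np.arange(N).reshape(N, 1, 1)    # batch index, broadcast over the grid
    p0 = (n * H + y0) * W + x0           # linear index of (n, y0, x0)
    # +W moves one row down, +1 moves one column right.
    p = np.stack([p0, p0 + W, p0 + 1, p0 + W + 1])
    return flat[p]                       # one gather for all four corners
```

Splitting the result back into the four corner values is then a cheap slice (`out[0]` … `out[3]`) rather than three extra gathers.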
TensorRT Version: 7
GPU Type: NVIDIA GeForce GTX 1050 Ti
Nvidia Driver Version: 460.73.01
CUDA Version: 10.1
CUDNN Version: 126.96.36.199-1
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 3.7
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.6
Baremetal or Container (if container which image + tag):
Python script for timing the bilinear sampling in PyTorch and exporting the ONNX model:
```python
import time

import torch


def time_stamp(sync=True):
    # Wait for queued CUDA kernels so the wall-clock time covers the GPU work
    if sync:
        torch.cuda.synchronize()
    return time.time()


def bilinear_sample_noloop(image, grid):
    """
    Bilinear sampling with no loops.
    :param image: sampling source of shape [N, C, H, W]
    :param grid: sampling pixel coordinates (x, y), in pixel units, of shape [N, grid_H, grid_W, 2]
    :return: sampling result of shape [N, C, grid_H, grid_W]
    """
    Nt, C, H, W = image.shape
    grid_H = grid.shape[1]
    grid_W = grid.shape[2]
    xgrid, ygrid = grid.split([1, 1], dim=-1)
    # Only points whose whole 2x2 neighbourhood lies inside the image are valid
    mask = ((xgrid >= 0) & (ygrid >= 0) & (xgrid < W - 1) & (ygrid < H - 1)).float()
    x0 = torch.floor(xgrid)
    x1 = x0 + 1
    y0 = torch.floor(ygrid)
    y1 = y0 + 1
    # Bilinear weights, permuted to [1, N, grid_H, grid_W] to broadcast over C
    wa = ((x1 - xgrid) * (y1 - ygrid)).permute(3, 0, 1, 2)
    wb = ((x1 - xgrid) * (ygrid - y0)).permute(3, 0, 1, 2)
    wc = ((xgrid - x0) * (y1 - ygrid)).permute(3, 0, 1, 2)
    wd = ((xgrid - x0) * (ygrid - y0)).permute(3, 0, 1, 2)
    # Masked-out points are redirected to pixel (0, 0); their output is zeroed below
    x0 = (x0 * mask).view(Nt, grid_H, grid_W).long()
    y0 = (y0 * mask).view(Nt, grid_H, grid_W).long()
    x1 = (x1 * mask).view(Nt, grid_H, grid_W).long()
    y1 = (y1 * mask).view(Nt, grid_H, grid_W).long()
    # Batch index for every grid point, shape [N, grid_H, grid_W]
    ind = torch.arange(Nt, device=image.device)  # torch.linspace(0, Nt - 1, Nt, device=image.device)
    ind = ind.view(Nt, 1).expand(-1, grid_H).view(Nt, grid_H, 1).expand(-1, -1, grid_W).long()
    image = image.permute(1, 0, 2, 3)
    # Each of these four advanced-indexing reads becomes a Gather layer in TRT
    output_tensor = (image[:, ind, y0, x0] * wa + image[:, ind, y1, x0] * wb +
                     image[:, ind, y0, x1] * wc + image[:, ind, y1, x1] * wd).permute(1, 0, 2, 3)
    output_tensor *= mask.permute(0, 3, 1, 2).expand(-1, C, -1, -1)
    image = image.permute(1, 0, 2, 3)  # restores the layout of the local variable only
    return output_tensor, mask


class BilinearSamplingBlock(torch.nn.Module):
    def __init__(self):
        super(BilinearSamplingBlock, self).__init__()

    def forward(self, inputs):
        image, grid = inputs
        samples, mask = bilinear_sample_noloop(image, grid)
        return [samples, mask]


if __name__ == '__main__':
    model = torch.nn.DataParallel(BilinearSamplingBlock())
    model.cuda()
    dummy_image = torch.ones(3600, 1, 45, 80, device='cuda').float()
    dummy_grid = torch.ones(3600, 9, 9, 2, device='cuda').float()
    dummy_inputs = [dummy_image, dummy_grid]
    with torch.no_grad():
        model.module.eval()
        smax = 100
        stime = 0
        for i in range(smax):
            t1 = time_stamp()
            ex_out = model.module(dummy_inputs)
            t2 = time_stamp()
            stime += t2 - t1
        print("Mean model time = ", 1000 * stime / smax, " ms")
        torch.onnx.export(model.module, dummy_inputs, "sampling_block.onnx", opset_version=11,
                          example_outputs=ex_out, verbose=True)
```
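Written out, the weights `wa`, `wb`, `wc`, `wd` in `bilinear_sample_noloop` are the usual bilinear area weights, with $m$ the value of `mask` for the point, $I$ the image and $(x, y)$ the grid coordinates:

```latex
\begin{aligned}
x_0 &= \lfloor x \rfloor, \quad x_1 = x_0 + 1, \quad y_0 = \lfloor y \rfloor, \quad y_1 = y_0 + 1 \\
w_a &= (x_1 - x)(y_1 - y), \qquad w_b = (x_1 - x)(y - y_0) \\
w_c &= (x - x_0)(y_1 - y), \qquad w_d = (x - x_0)(y - y_0) \\
O(x, y) &= m \left[ w_a I(y_0, x_0) + w_b I(y_1, x_0) + w_c I(y_0, x_1) + w_d I(y_1, x_1) \right] \\
m &= \begin{cases} 1 & \text{if } 0 \le x < W - 1 \text{ and } 0 \le y < H - 1 \\ 0 & \text{otherwise} \end{cases}
\end{aligned}
```

Since $w_a + w_b = x_1 - x$ and $w_c + w_d = x - x_0$, the four weights always sum to 1; each of the four $I(\cdot,\cdot)$ terms is one Gather in the exported graph.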
Steps to reproduce:
- Run the bilinear sampling Python script to get the PyTorch time and the ONNX model;
- Run ‘trtexec --onnx=sampling_block.onnx --workspace=128 --dumpProfile’;
- Compare the PyTorch time with the TRT time and look for the “Gather_NNN” layers in the TRT profile.
Part of the TRT model profile from trtexec (“Time” is accumulated over all profiled iterations, “Avg. Time” is per iteration):
| Layer | Time (ms) | Avg. Time (ms) | Time % |
|---|---|---|---|
| Mul_146 | 5.82 | 0.03 | 0.5 |
| Add_147 | 8.50 | 0.04 | 0.7 |
| Gather_148 | 214.39 | 0.97 | 17.3 |
| Gather_174 | 214.25 | 0.97 | 17.3 |
| Gather_201 | 213.88 | 0.97 | 17.3 |
| Gather_228 | 214.48 | 0.97 | 17.3 |
| Add_237)) | 25.01 | 0.11 | 2.0 |
| Mul_251 | 7.84 | 0.04 | 0.6 |
| Total | 1238.40 | 5.60 | 100.0 |
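To put a number on the Gather cost, the Gather rows can be totalled directly from the profile. `gather_share` is a hypothetical helper, assuming rows as trtexec prints them (whitespace-separated, ending in three numeric columns: total time, average time, time %); the row layout is taken from the excerpt in this post, not from trtexec documentation. On that excerpt the four Gather layers account for about 3.88 ms of the 5.60 ms per inference, roughly 69%.

```python
def gather_share(profile_text):
    """Summarise Gather layers in a `trtexec --dumpProfile` excerpt.

    Returns (gather_count, gather_avg_ms, total_avg_ms); total_avg_ms is
    None if no "Total" row is present.
    """
    gather_count, gather_avg_ms, total_avg_ms = 0, 0.0, None
    for line in profile_text.strip().splitlines():
        parts = line.rsplit(None, 3)  # split from the right: layer names may contain spaces
        if len(parts) != 4:
            continue
        name = parts[0]
        try:
            avg_ms = float(parts[2])  # "Avg. Time (ms)" column
        except ValueError:
            continue  # header row
        if name == "Total":
            total_avg_ms = avg_ms
        elif name.startswith("Gather"):
            gather_count += 1
            gather_avg_ms += avg_ms
    return gather_count, gather_avg_ms, total_avg_ms


PROFILE = """
Layer Time (ms) Avg. Time (ms) Time %
Mul_146 5.82 0.03 0.5
Add_147 8.50 0.04 0.7
Gather_148 214.39 0.97 17.3
Gather_174 214.25 0.97 17.3
Gather_201 213.88 0.97 17.3
Gather_228 214.48 0.97 17.3
Add_237)) 25.01 0.11 2.0
Mul_251 7.84 0.04 0.6
Total 1238.40 5.60 100.0
"""

count, g, total = gather_share(PROFILE)
print(f"{count} Gather layers: {g:.2f} ms of {total:.2f} ms per inference ({100 * g / total:.1f}%)")
# 4 Gather layers: 3.88 ms of 5.60 ms per inference (69.3%)
```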