How to convert a model with grid_sample to TensorRT with INT8 quantization?

Description

I am trying to convert a model that uses torch.nn.functional.grid_sample from PyTorch (1.9) to TensorRT (7) with INT8 quantization through ONNX (opset 11).
Opset 11 does not support exporting grid_sample to ONNX. Thus, following the advice in (How to optimize the custom bilinear sampling alternative to grid_sample for TensorRT inference?), I used ONNX GraphSurgeon together with the external GridSamplePlugin, as proposed here (GitHub - TrojanXu/onnxparser-trt-plugin-sample: A sample for onnxparser working with trt user defined plugins for TRT7.0). With it, the conversion to TensorRT (both with and without INT8 quantization) is successful.
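For reference, the node-replacement step looks roughly like the sketch below (a simplified illustration; the placeholder op name produced at export time and the plugin name "GridSampler" follow the TrojanXu sample and may differ in other setups):

import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("grid_sampling_block.onnx"))

for node in graph.nodes:
    # The placeholder op name emitted at ONNX export time is setup-specific.
    if node.op == "GridSamplePlaceholder":
        node.op = "GridSampler"  # name the TensorRT plugin is registered under
        node.attrs = {"interpolation_mode": 0, "padding_mode": 0, "align_corners": 0}

graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "grid_sampling_block_plugin.onnx")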
The PyTorch model and the TensorRT model without INT8 quantization produce results that are close to identical (MSE on the order of 1e-10). But for TensorRT with INT8 quantization the MSE is much higher (185).
The grid_sample operator takes two inputs: the input signal and the sampling grid, and both should be of the same type. In the GridSamplePlugin only kFLOAT and kHALF processing is implemented.
In my case the X coordinate in the absolute sampling grid (before it is converted to the normalized one required by grid_sample) varies in the range [-d; W+d], and the Y coordinate in [-d; H+d]. The maximal value of W is 640, and of H is 360, and the coordinates may take non-integer values in this range.
For testing purposes I created a test model that contains only the grid_sample layer. In this case the TensorRT results with and without INT8 quantization are identical.

So the questions are:

  • Is it valid to apply INT8 quantization to operations that have at least one indexing input (like grid_sample)? Doesn’t such quantization lead to a significant change of the result (if we apply INT8 quantization to an input with the range [0…640), for example)? A rough estimate is given after this list.
  • How does INT8 quantization work with a custom plugin if only FP32 and FP16 are implemented in the plugin code?
  • Is the identical result of the test network in TensorRT with and without INT8 quantization due to the fact that the grid_sample input is actually a network input?
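To make the first question concrete, here is a rough back-of-the-envelope estimate (my own illustration, not a measurement): a per-tensor symmetric INT8 scale fitted to the sampling coordinates would give quantization steps of several pixels, while bilinear sampling needs sub-pixel precision.

W = 640
step_absolute = W / 127.0                         # ~5.0 px if the absolute grid [0, 640) were quantized
step_normalized = 2.0 / 254.0                     # INT8 step for a grid normalized to [-1, 1]
step_in_pixels = step_normalized * (W - 1) / 2.0  # ~2.5 px in X after de-normalization
print(step_absolute, step_in_pixels)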

Environment

TensorRT Version: 7.2.3.4
GPU Type: NVIDIA GeForce GTX 1050 Ti
Nvidia Driver Version: 470.63.01
CUDA Version: 10.2.89
CUDNN Version: 8.1.1
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 3.7
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.9
Baremetal or Container (if container which image + tag):

Relevant Files

Here is the code of the test model:

import torch
import numpy as np
import cv2

BATCH_SIZE = 1
WIDTH = 640
HEIGHT = 360

def calculate_grid(B, H, W, dtype, device='cuda'):
    # Build an absolute (pixel-coordinate) sampling grid of shape (B, 2, H, W);
    # the X coordinate is additionally shifted by 0.25 * Y so the warp is non-trivial.
    xx = torch.arange(0, W, device=device).view(1, -1).repeat(H, 1).type(dtype)
    yy = torch.arange(0, H, device=device).view(-1, 1).repeat(1, W).type(dtype)
    xx = xx + yy * 0.25
    if B > 1:
        xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
        yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
    else:
        xx = xx.view(1, 1, H, W)
        yy = yy.view(1, 1, H, W)
    vgrid = torch.cat((xx, yy), 1).type(dtype)
    return vgrid.type(dtype)

def modify_grid(vgrid, H, W):
    # Normalize absolute pixel coordinates to the [-1, 1] range expected by
    # grid_sample and move the channels last: (B, 2, H, W) -> (B, H, W, 2).
    vgrid = torch.cat([
        torch.sub(2.0 * vgrid[:, :1, :, :].clone() / max(W - 1, 1), 1.0),
        torch.sub(2.0 * vgrid[:, 1:2, :, :].clone() / max(H - 1, 1), 1.0),
        vgrid[:, 2:, :, :]], dim=1)
    vgrid = vgrid.permute(0, 2, 3, 1)
    return vgrid

class GridSamplingBlock(torch.nn.Module):

    def __init__(self):
        super(GridSamplingBlock, self).__init__()

    def forward(self, input, vgrid):
        output = torch.nn.functional.grid_sample(input, vgrid)
        return output

if __name__ == '__main__':
    model = torch.nn.DataParallel(GridSamplingBlock())
    model.cuda()
    print("Reading inputs")
    img = cv2.imread("result/left_frame_rect_0373.png")
    img = cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), (WIDTH, HEIGHT))
    img_in = torch.from_numpy(img.astype(float)).view(1, 1, HEIGHT, WIDTH).cuda()
    vgrid = calculate_grid(BATCH_SIZE, HEIGHT, WIDTH, img_in.dtype)
    vgrid = modify_grid(vgrid, HEIGHT, WIDTH)
    np.save("result/grid", vgrid.cpu().detach().numpy())
    print("Getting output")
    with torch.no_grad():
        model.module.eval()
        img_out = model.module(img_in, vgrid)
        img = img_out.cpu().detach().numpy().squeeze()
        cv2.imwrite("result/grid_sample_test_output.png", img.astype(np.uint8))

The saved grid is used for both the calibration and the inference of the TensorRT model.
GridSamplerPlugin git: GitHub - TrojanXu/onnxparser-trt-plugin-sample: A sample for onnxparser working with trt user defined plugins for TRT7.0
TensorRT OSS git: GitHub - NVIDIA/TensorRT: TensorRT is a C++ library for high performance inference on NVIDIA GPUs and deep learning accelerators.
Numpy files reading in C++: GitHub - llohse/libnpy: C++ library for reading and writing of numpy's .npy files
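The calibration feeds the saved grid together with the image, roughly along the lines of this simplified sketch (the file names and the input tensor names "input" and "grid" below are placeholders, not the exact ones from my scripts):

import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt

class GridSampleCalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, cache_file="calib.cache"):
        super().__init__()
        self.cache_file = cache_file
        # Single calibration batch: the preprocessed image and the saved grid.
        self.batches = iter([{
            "input": np.load("result/image.npy").astype(np.float32),
            "grid": np.load("result/grid.npy").astype(np.float32),
        }])
        self.device_buffers = {}

    def get_batch_size(self):
        return 1

    def get_batch(self, names):
        try:
            batch = next(self.batches)
        except StopIteration:
            return None  # no more calibration data
        ptrs = []
        for name in names:
            arr = np.ascontiguousarray(batch[name])
            if name not in self.device_buffers:
                self.device_buffers[name] = cuda.mem_alloc(arr.nbytes)
            cuda.memcpy_htod(self.device_buffers[name], arr)
            ptrs.append(int(self.device_buffers[name]))
        return ptrs

    def read_calibration_cache(self):
        try:
            with open(self.cache_file, "rb") as f:
                return f.read()
        except FileNotFoundError:
            return None

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)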

Steps To Reproduce

Hi,
Please refer to the link below for the sample guide.

Refer to the installation steps in the link in case you are missing anything.

However, the suggested approach is to use TRT NGC containers to avoid any system-dependency-related issues.

In order to run the Python samples, make sure the TRT Python packages are installed when using the NGC container:
/opt/tensorrt/python/python_setup.sh

If you are trying to run a custom model, please share your model and script with us so that we can assist you better.
Thanks!

Thank you!
I am already familiar with these instructions.
Unfortunately, I cannot share the model with the custom modifications; I can only say that the base model before modification was the original AnyNet. Is it possible to proceed with my questions given this limitation?

Hi,

It still works. It just means that the plugin will not run in INT8.

Please allow us some time to get back to you on the other questions.

Thank you.


Hi,

grid_sample seems to be a custom plugin; to answer the questions generally:

  • Is it valid to apply INT8 quantization to functions with at least one indexing input (like grid_sample)? Doesn’t such quantization lead to a significant change of the result (if we apply INT8 quantization to an input with the range [0…640), for example)? - No, it doesn’t make sense to quantize index values.
  • How does INT8 quantization work with the custom plugin if only FP32 and FP16 are implemented in the plugin code? - As mentioned previously, if the plugin does not support INT8 it will use its FP32/FP16 implementation.
  • Is the same result of the test network in TensorRT with and without INT8 quantization obtained due to the fact that the grid_sample input is actually the network input? - Not exactly sure about the question here. If you are getting similar results with and without INT8 quantization, it seems that this grid_sample plugin is running in FP32/FP16 in both cases. As long as the rest of the layers are properly quantized, TRT will handle the reformatting (see the rough sketch after this list).
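For illustration of the last point, building an INT8 engine while explicitly keeping the plugin (and the layers feeding its grid input) in FP32 could look roughly like the sketch below with the TensorRT Python API (the layer-name check is only an example, not taken from your model):

import tensorrt as trt

# Rough sketch: assumes an existing `builder`, `network` and `calibrator`
# (an IInt8EntropyCalibrator2 instance); adjust the layer matching to your network.
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30
config.set_flag(trt.BuilderFlag.INT8)
config.set_flag(trt.BuilderFlag.STRICT_TYPES)   # make per-layer precision requests binding
config.int8_calibrator = calibrator

for i in range(network.num_layers):
    layer = network.get_layer(i)
    # Keep the plugin layer (and anything feeding its grid input) out of INT8.
    if layer.type == trt.LayerType.PLUGIN_V2 or "grid" in layer.name.lower():
        layer.precision = trt.float32
        layer.set_output_type(0, trt.float32)

engine = builder.build_engine(network, config)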

Thank you.


Hi, SergiySnake,

I also used this grid_sample plugin in my network and got stuck on the same issue when doing INT8 quantization. How did you solve this problem?

Best,
Thank you.
