Different FP16 inference results with TensorRT and PyTorch

I created a network with a single convolution layer and used the same weights for TensorRT and PyTorch.
When I use float32, the results are almost equal.
But when I use float16 in TensorRT, I get float32 in the output and different results.
Tested on Jetson TX2 and Tesla P100.

import torch
from torch import nn
import numpy as np
import tensorrt as trt

import pycuda.driver as cuda
import pycuda.autoinit

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

class PytorchModel(nn.Module):
    def __init__(self, weights):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=(3, 3), bias=False)
        self.conv.weight.data = torch.Tensor(weights)

    def forward(self, x):
        x = self.conv(x)
        return x

def calc_pytorch(data, weights, use_fp16):
    if use_fp16:
        np_dtype = np.float16
    else:
        np_dtype = np.float32

    data = data.astype(dtype=np_dtype)
    weights = weights.astype(dtype=np_dtype)
    model = PytorchModel(weights)
    model.eval()
    model.to('cuda')
    if use_fp16:
        model.half()

    data = torch.Tensor(data)
    data = data.unsqueeze(dim=0)
    data = data.to('cuda')
    if use_fp16:
        data = data.half()

    with torch.no_grad():
        output = model(data).cpu().numpy()

    output = output.ravel()

    return output

# Simple helper data class that's a little nicer to use than a 2-tuple.
class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()

# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
def allocate_buffers(engine):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()
    for binding in engine:
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Allocate host and device buffers
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to device bindings.
        bindings.append(int(device_mem))
        # Append to the appropriate list.
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
            print('output engine.get_binding_dtype(binding)', engine.get_binding_dtype(binding))

    return inputs, outputs, bindings, stream

# This function is generalized for multiple inputs/outputs.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    # Transfer input data to the GPU.
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
    # Run inference.
    context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU.
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return only the host outputs.
    return [out.host for out in outputs]

def GiB(val):
    return val * 1 << 30

def build_engine(weights, use_fp16):
    if use_fp16:
        trt_dtype = trt.float16
    else:
        trt_dtype = trt.float32

    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network:

        if use_fp16:
            builder.fp16_mode = True
            builder.strict_type_constraints = True

        print('builder.platform_has_fast_fp16', builder.platform_has_fast_fp16)
        print('builder.fp16_mode', builder.fp16_mode)

        builder.max_workspace_size = GiB(1)

        input_tensor = network.add_input(name='input', dtype=trt_dtype, shape=[1, 3, 3])
        input_tensor.name = 'input'
        print('input_tensor.dtype', input_tensor.dtype)

        conv_w = trt.Weights(weights)
        print('conv_w.dtype', conv_w.dtype)
        conv_b = trt.Weights(type=trt_dtype)
        print('conv_b.dtype', conv_b.dtype)
        conv = network.add_convolution(input=input_tensor, num_output_maps=2,
                                       kernel_shape=(3, 3),
                                       kernel=conv_w, bias=conv_b)

        conv.precision = trt_dtype

        network.mark_output(tensor=conv.get_output(0))

        return builder.build_cuda_engine(network)

def calc_tensorrt(data, weights, use_fp16):
    if use_fp16:
        np_dtype = np.float16
    else:
        np_dtype = np.float32

    data = data.astype(dtype=np_dtype)
    weights = weights.astype(dtype=np_dtype)

    with build_engine(weights, use_fp16) as engine:
        inputs, outputs, bindings, stream = allocate_buffers(engine)
        with engine.create_execution_context() as context:
            np.copyto(inputs[0].host, data.ravel())
            [output_trt] = do_inference(context, bindings=bindings, inputs=inputs,
                                        outputs=outputs, stream=stream)

    return output_trt

def main():
    weights = [[[[0.000001, 0.000002, 0.000003],
                 [0.000004, 0.000005, 0.000006],
                 [7, 8, 9]]],
               [[[9, 8, 7],
                 [6, 5, 4],
                 [30000, 20000, 10000]]]]

    weights = np.array(weights)

    data = [[[0.0001, 0.0002, 0.0003],
             [0.0004, 0.0005, 0.0006],
             [0.0007, 0.0008, 0.0009]]]

    data = np.array(data)

    print('=======Pytorch FP32=======')
    output_pt = calc_pytorch(data, weights, use_fp16=False)
    print('output_pt.dtype', output_pt.dtype)
    print('output_pt {:.16f} {:.16f}'.format(output_pt[0], output_pt[1]))

    print('=======Pytorch FP16=======')
    output_pt = calc_pytorch(data, weights, use_fp16=True)
    print('output_pt.dtype', output_pt.dtype)
    print('output_pt {:.16f} {:.16f}'.format(output_pt[0], output_pt[1]))

    print('=======TensorRT FP32=======')
    output_trt = calc_tensorrt(data, weights, use_fp16=False)
    print('output_trt.dtype', output_trt.dtype)
    print('output_trt {:.16f} {:.16f}'.format(output_trt[0], output_trt[1]))

    print('=======TensorRT FP16=======')
    output_trt = calc_tensorrt(data, weights, use_fp16=True)
    print('output_trt.dtype', output_trt.dtype)
    print('output_trt {:.16f} {:.16f}'.format(output_trt[0], output_trt[1]))

if __name__ == '__main__':
    print('TensorRT version:', trt.__version__)
    main()

Result on Tesla P100:

TensorRT version: 5.0.2.6
=======Pytorch FP32=======
output_pt.dtype float32
output_pt 0.0194000080227852 46.0118980407714844
=======Pytorch FP16=======
output_pt.dtype float16
output_pt 0.0193939208984375 46.0000000000000000
=======TensorRT FP32=======
builder.platform_has_fast_fp16 True
builder.fp16_mode False
input_tensor.dtype DataType.FLOAT
conv_w.dtype DataType.FLOAT
conv_b.dtype DataType.FLOAT
output engine.get_binding_dtype(binding) DataType.FLOAT
output_trt.dtype float32
output_trt 0.0194000080227852 46.0118980407714844
=======TensorRT FP16=======
builder.platform_has_fast_fp16 True
builder.fp16_mode True
input_tensor.dtype DataType.HALF
conv_w.dtype DataType.HALF
conv_b.dtype DataType.HALF
output engine.get_binding_dtype(binding) DataType.FLOAT
output_trt.dtype float32
output_trt 0.0193939208984375 46.0312500000000000

Result on Jetson TX2:

TensorRT version: 5.0.6.3
=======Pytorch FP32=======
output_pt.dtype float32
output_pt 0.0194000061601400 46.0118942260742188
=======Pytorch FP16=======
output_pt.dtype float16
output_pt 0.0193939208984375 46.0000000000000000
=======TensorRT FP32=======
builder.platform_has_fast_fp16 True
builder.fp16_mode False
input_tensor.dtype DataType.FLOAT
conv_w.dtype DataType.FLOAT
conv_b.dtype DataType.FLOAT
output engine.get_binding_dtype(binding) DataType.FLOAT
output_trt.dtype float32
output_trt 0.0194000080227852 46.0118980407714844
=======TensorRT FP16=======
builder.platform_has_fast_fp16 True
builder.fp16_mode True
input_tensor.dtype DataType.HALF
conv_w.dtype DataType.HALF
conv_b.dtype DataType.HALF
output engine.get_binding_dtype(binding) DataType.FLOAT
output_trt.dtype float32
output_trt 0.0193939208984375 46.0312500000000000

We are reviewing and will keep you updated.

Hello, per engineering: “Inputs and outputs for networks can only be FP32 right now. We do not support fp16 output yet. So I believe this is expected.”
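
A quick way to see this from the script above is to inspect the engine bindings directly. A minimal sketch, reusing the build_engine function defined earlier (the dummy weights here are only for illustration):

import numpy as np

# Build an FP16 engine with placeholder weights and print each binding's dtype.
# On TensorRT 5.x the marked network output is reported as DataType.FLOAT even
# though fp16_mode and strict_type_constraints are enabled inside build_engine.
dummy_weights = np.ones((2, 1, 3, 3), dtype=np.float16)
with build_engine(dummy_weights, use_fp16=True) as engine:
    for binding in engine:
        kind = 'input ' if engine.binding_is_input(binding) else 'output'
        print(kind, binding, engine.get_binding_dtype(binding))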

That’s a strange answer.
What is the purpose of the set_output_type function?
https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/infer/Graph/LayerBase.html#tensorrt.ILayer.set_output_type

Hi alexei.khatin,
This function sets the layer's output to the type you want, regardless of the TensorRT engine precision you set with builder.build_engine.
It is used for partial quantization.
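
For example, a minimal sketch (assuming the TensorRT 5.x Python API and the conv layer created in the build_engine function above) of requesting FP16 for just that one layer:

# Partial-quantization sketch: ask the builder to run only this layer in FP16.
# strict_type_constraints makes the builder treat the requested types as
# requirements rather than hints. Per the earlier reply, tensors that are
# network inputs or outputs may still be kept in FP32.
builder.fp16_mode = True
builder.strict_type_constraints = True
conv.precision = trt.float16            # run this layer's kernel in FP16
conv.set_output_type(0, trt.float16)    # request FP16 for this layer's output tensor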

Hi,
Can you try running your model with the trtexec command and share the --verbose log in case the issue persists?
https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/trtexec

You can refer to the link below for the list of supported operators; if any operator is not supported, you will need to create a custom plugin to support that operation.

Also, please share your model and script, if not already shared, so that we can help you better.

Meanwhile, for some common errors and queries, please refer to the link below:

Thanks!