How exactly are you supposed to do explicit quantization?

Description

I’ve been looking into doing explicit quantization with TensorRT, and I guess there’s a flaw in my logic somewhere because I haven’t been able to get an example working. My understanding was that explicit quantization occurs when the network contains Q/DQ layers. My natural conclusion was that if I started with, say, a PyTorch model, I would have to quantize it myself in PyTorch (using eager-mode quantization, manually placing Q/DQ layers), export it to ONNX, and then convert that to a TensorRT engine.

First: is this workflow correct? I can’t seem to find any information on how else I would add Q/DQ layers. I’ve seen this tool but I haven’t had a chance to try it out yet - is it really necessary?

If this workflow is correct, then for some reason I haven’t gotten it to work. The script included below trains a CNN on MNIST and quantizes it using a few different qconfigs (per_tensor vs. per_channel, affine vs. symmetric). Despite this variety, every single one of these models produces the same error output when I attempt to build a TRT engine:

03-04 02:55:58 | 857653 | E | Global | ModelImporter.cpp:726: While parsing node number 2 [QuantizeLinear -> "/0/QuantizeLinear_output_0"]:
03-04 02:55:58 | 857653 | E | Global | ModelImporter.cpp:727: --- Begin node ---
03-04 02:55:58 | 857653 | E | Global | ModelImporter.cpp:728: input: "input"
input: "/0/Constant_1_output_0"
input: "/0/Constant_output_0"
output: "/0/QuantizeLinear_output_0"
name: "/0/QuantizeLinear"
op_type: "QuantizeLinear"

03-04 02:55:58 | 857653 | E | Global | ModelImporter.cpp:729: --- End node ---
03-04 02:55:58 | 857653 | E | Global | ModelImporter.cpp:732: ERROR: builtin_op_importers.cpp:1235 In function QuantDequantLinearHelper:
[6] Assertion failed: shiftIsAllZeros(zeroPoint) && "TRT only supports symmetric quantization - zeroPt must be all zeros"
03-04 02:55:58 | 857653 | E | Global | Failed to convert ../models/mnist/mnist_quantized_per_tensor_symmetric_one_shot.onnx to ../models/mnist/mnist_quantized_per_tensor_symmetric_one_shot.engine

This was surprising to me because even those models with symmetric qconfigs failed to build. I manually inspected the zero points of all the weights, and they are indeed zero.
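
One more observation from the log: the failing node ("/0/QuantizeLinear") takes "input" as its input, i.e. it quantizes an activation, not a weight, and its scale/zero point come from Constant nodes rather than initializers. If I understand correctly, PyTorch’s activation observers default to quint8, where even the symmetric scheme pins the zero point at 128, so zero-centered weights alone may not be enough. Here is a quick sketch for dumping every Q/DQ zero point in the exported file (using the onnx package; the Constant-node handling matches how torch.onnx.export emitted this particular graph):

import onnx
from onnx import numpy_helper

m = onnx.load("mnist_quantized_per_tensor_symmetric_one_shot.onnx")

# Constants may live in graph.initializer or as Constant-node outputs;
# torch.onnx.export used Constant nodes for this model (see the log).
consts = {t.name: numpy_helper.to_array(t) for t in m.graph.initializer}
for node in m.graph.node:
    if node.op_type == "Constant":
        for attr in node.attribute:
            if attr.name == "value":
                consts[node.output[0]] = numpy_helper.to_array(attr.t)

for node in m.graph.node:
    if node.op_type in ("QuantizeLinear", "DequantizeLinear"):
        # Inputs are (x, scale, zero_point); zero_point is optional.
        zp = consts.get(node.input[2]) if len(node.input) > 2 else None
        print(node.name, None if zp is None else (zp.dtype, zp))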

So how are you supposed to do explicit quantization? I haven’t found a good sample demonstrating it.

One wrinkle in this is that I’m running the Python script on one machine and attempting to build the engine on another. Could there be a version mismatch? I’ve included both environments’ details for completeness:

Environment

Python Environment

GPU Type: GeForce RTX 3090
Nvidia Driver Version: 560.94
CUDA Version: 12.6
Operating System + Version: Ubuntu 22.04.3 LTS x86_64
Python Version: 3.12.2
PyTorch Version: 2.5.1

TRT Environment

TensorRT Version: 8.5.3
GPU Type: A10G
Nvidia Driver Version: 550.127.08
CUDA Version: 12.4
Operating System + Version: Ubuntu 20.04.6 LTS x86_64

Steps To Reproduce

This script produces some quantized models.

import copy

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


# Define a simple CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # Output: 32x28x28
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # Output: 64x28x28
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)  # Output: 64x14x14
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)  # 10 output classes

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.pool(self.relu2(self.conv2(x)))
        x = x.reshape(x.size(0), -1)  # Flatten the tensor
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x


# Training function
def train(model, criterion, optimizer, train_loader, epochs, device):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(
            f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%"
        )


# Evaluation function
def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"Test Accuracy: {100 * correct / total:.2f}%")

def print_quantized_weights(model):
    print("\nQuantized Weight Parameters:")
    for name, module in model.named_modules():
        if isinstance(module, (torch.ao.nn.quantized.Linear, torch.ao.nn.quantized.Conv2d)):
            weight = module.weight()
            # q_scale()/q_zero_point() raise on per-channel tensors, which
            # expose q_per_channel_scales()/q_per_channel_zero_points() instead.
            if weight.qscheme() in (torch.per_tensor_affine, torch.per_tensor_symmetric):
                scale = weight.q_scale()
                zero_point = weight.q_zero_point()
            else:
                scale = weight.q_per_channel_scales()
                zero_point = weight.q_per_channel_zero_points()
            print(f"Layer: {name}")
            print(f"  Scale: {scale}")
            print(f"  Zero Point: {zero_point}")
            print("")

if __name__ == "__main__":
    # Hyperparameters
    BATCH_SIZE = 64
    LEARNING_RATE = 0.001
    EPOCHS = 5

    # Transformations for normalization
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                (0.1307,), (0.3081,)
            ),  # Normalizing with MNIST mean and std
        ]
    )

    # Load MNIST dataset
    train_dataset = torchvision.datasets.MNIST(
        root="./data", train=True, transform=transform, download=True
    )
    test_dataset = torchvision.datasets.MNIST(
        root="./data", train=False, transform=transform, download=True
    )

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Initialize model, loss function, and optimizer
    model = CNN().to("cuda")
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train(model, criterion, optimizer, train_loader, EPOCHS, "cuda")
    evaluate(model, test_loader, "cuda")

    model = model.to("cpu")

    onnx_filename = "mnist_model.onnx"
    dummy_input = torch.randn(1, 1, 28, 28)
    torch.onnx.export(
        model, 
        dummy_input, 
        onnx_filename, 
        input_names=['input'], 
        output_names=['output'],
        opset_version=13
    )

    quantization_schemes = ["per_tensor_affine", "per_tensor_symmetric", "per_channel_affine", "per_channel_symmetric"]
    # calibration_schemes = ["random", "one_shot", "full"]
    # quantization_schemes = ["per_channel_affine"]
    calibration_schemes = ["one_shot"]

    for quantization_scheme in quantization_schemes:
        for calibration_scheme in calibration_schemes:
            print(
                f"\nQuantization scheme: {quantization_scheme}, Calibration scheme: {calibration_scheme}"
            )

            model_int8 = copy.deepcopy(model)
            model_int8.eval()

            torch.quantization.fuse_modules(
                model_int8,
                [["conv1", "relu1"], ["conv2", "relu2"], ["fc1", "relu3"]],
                inplace=True,
            )

            model_int8 = nn.Sequential(
                torch.quantization.QuantStub(),
                model_int8,
                torch.quantization.DeQuantStub(),
            )
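            # When exported to ONNX, QuantStub/DeQuantStub become
            # QuantizeLinear/DequantizeLinear nodes at the graph boundary.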

            # model_int8.qconfig = torch.quantization.get_default_qconfig('fbgemm')
            # Default uses per_channel_affine, which may be doing a much better job of quantizing the weights
            if quantization_scheme == "per_tensor_affine":
                qconfig = torch.quantization.QConfig(
                    activation=torch.ao.quantization.HistogramObserver.with_args(
                        qscheme=torch.per_tensor_affine, reduce_range=True
                    ),
                    weight=torch.ao.quantization.HistogramObserver.with_args(
                        qscheme=torch.per_tensor_affine, dtype=torch.qint8
                    ),
                )
            elif quantization_scheme == "per_tensor_symmetric":
                qconfig = torch.quantization.QConfig(
                    activation=torch.ao.quantization.HistogramObserver.with_args(
                        qscheme=torch.per_tensor_symmetric, reduce_range=True
                    ),
                    weight=torch.ao.quantization.HistogramObserver.with_args(
                        qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
                    ),
                )
            elif quantization_scheme == "per_channel_affine":
                qconfig = torch.quantization.QConfig(
                    activation=torch.ao.quantization.HistogramObserver.with_args(
                        reduce_range=True
                    ),
                    weight=torch.ao.quantization.PerChannelMinMaxObserver.with_args(
                        dtype=torch.qint8, qscheme=torch.per_channel_symmetric
                    ),
                )
                # qconfig = torch.quantization.get_default_qconfig("fbgemm")
            elif quantization_scheme == "per_channel_symmetric":
                qconfig = torch.quantization.QConfig(
                    activation=torch.ao.quantization.HistogramObserver.with_args(
                        qscheme=torch.per_tensor_symmetric, reduce_range=True
                    ),
                    weight=torch.ao.quantization.PerChannelMinMaxObserver.with_args(
                        dtype=torch.qint8, qscheme=torch.per_channel_symmetric
                    ),
                )
            model_int8.qconfig = qconfig
            torch.quantization.prepare(model_int8, inplace=True)

            with torch.inference_mode():
                if calibration_scheme == "random":
                    model_int8(torch.randn(1, 1, 28, 28))
                elif calibration_scheme == "one_shot":
                    for images, labels in train_loader:
                        model_int8(images)
                        break
                elif calibration_scheme == "full":
                    for images, labels in train_loader:
                        model_int8(images)

            torch.quantization.convert(model_int8, inplace=True)

            evaluate(model_int8, test_loader, "cpu")
            print_quantized_weights(model_int8)

            # Convert the quantized model to ONNX and serialize it
            onnx_filename = f"mnist_quantized_{quantization_scheme}_{calibration_scheme}.onnx"
            dummy_input = torch.randn(1, 1, 28, 28)
            torch.onnx.export(
                model_int8, 
                dummy_input, 
                onnx_filename, 
                input_names=['input'], 
                output_names=['output'], 
                # dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
                opset_version=13
            )
            print(f"Quantized model exported to {onnx_filename}")

I can’t include the script I’m using to build the TRT engine, as it’s internal to our org (and I just don’t have the time to strip it down to its bare bones right now), but it’s pretty straightforward and we’ve used it without issues in the past. I’ll try to include it later. Perhaps there’s a builder config flag I’m missing? I have no clue.
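
In the meantime, here is a bare-bones sketch of what the build step does with the TensorRT Python API (this is not our actual script, and the file names are placeholders):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)

with open("mnist_quantized_per_tensor_symmetric_one_shot.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise SystemExit("ONNX parse failed")

config = builder.create_builder_config()
# As far as I can tell, the INT8 flag must be set even for explicitly
# quantized (Q/DQ) networks; no calibrator is needed in that case.
config.set_flag(trt.BuilderFlag.INT8)

engine = builder.build_serialized_network(network, config)
if engine is None:
    raise SystemExit("engine build failed")
with open("mnist_quantized_per_tensor_symmetric_one_shot.engine", "wb") as f:
    f.write(engine)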

Thanks!

I am currently working on a similar thing. As far as I understand, it is better to use PyTorch Quantization — Model Optimizer 0.25.0 for quantization if you plan to combine it with TensorRT later; the built-in PyTorch quantization methods are not the best tool for this task.
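
A minimal sketch of that flow, going by the Model Optimizer docs (mtq.quantize and mtq.INT8_DEFAULT_CFG are taken from those docs; double-check them against the version you install):

import modelopt.torch.quantization as mtq

# Calibration loop: run a few batches so the inserted quantizers can
# observe activation ranges (train_loader as in the script above).
def forward_loop(model):
    for i, (images, _) in enumerate(train_loader):
        model(images)
        if i >= 16:
            break

# Inserts TensorRT-friendly (symmetric int8) Q/DQ automatically; the
# result can then be exported with torch.onnx.export as usual.
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)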