Hello everyone,
I am currently working on deploying a 3D Deep Learning model on NVIDIA Jetson. My goal is to maximize performance by running the entire model (or as much as possible) in INT8 precision.
My current workflow is:
- Training: PyTorch model training.
- Quantization: QAT (Quantization-Aware Training) using the PyTorch quantization toolkit / TensorRT Model Optimizer (see the sketch right after this list).
- Export: Export to ONNX.
- Deployment: Building the engine with TensorRT 10.13 (for testing only).
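For reference, here is a minimal sketch of my quantization/export step. I am showing the pytorch-quantization toolkit path (one of the two toolkits mentioned above); the input shape is a placeholder, and `Deconv3DStack` refers to the code further down in this post:

```python
# Minimal QAT/export sketch (pytorch-quantization toolkit; shapes are placeholders).
# quant_modules.initialize() must run BEFORE the model is built, so that
# nn.ConvTranspose3d is also patched with input/weight quantizers.
import torch
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules

quant_modules.initialize()          # patches nn.Conv3d, nn.ConvTranspose3d, ...

model = Deconv3DStack(in_c=64)      # defined further down in this post
# ... calibration + QAT fine-tuning happens here ...

# Export QuantizeLinear/DequantizeLinear nodes instead of fake-quant ops:
quant_nn.TensorQuantizer.use_fb_fake_quant = True
model.eval()
dummy = torch.randn(1, 64, 4, 4, 4)  # placeholder input shape
torch.onnx.export(model, dummy, "model_qat.onnx", opset_version=17)
```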
The Issue: The quantization works perfectly for most of the model (standard 3D Convolutions), but the ConvTranspose3d layers remain in FP16 (or FP32) in the final engine. They do not seem to run in INT8 despite the QAT calibration.
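This is how I observe the fallback. A minimal sketch of my build + inspection step (TensorRT 10.x Python API; file names are placeholders):

```python
# Build the engine from the QAT ONNX and dump per-layer precision info.
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(0)  # explicit batch is the default in TRT 10
parser = trt.OnnxParser(network, logger)
with open("model_qat.onnx", "rb") as f:
    assert parser.parse(f.read()), parser.get_error(0)

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.INT8)  # with Q/DQ (explicit quantization) TRT picks INT8 where kernels exist
config.set_flag(trt.BuilderFlag.FP16)  # allow FP16 fallback for layers without INT8 kernels
config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED  # needed for full layer info

engine_bytes = builder.build_serialized_network(network, config)
engine = trt.Runtime(logger).deserialize_cuda_engine(engine_bytes)

# Dump per-layer info; the ConvTranspose3d layers show up as FP16/FP32 here.
inspector = engine.create_engine_inspector()
print(inspector.get_engine_information(trt.LayerInformationFormat.JSON))
```

(If I recall the flags correctly, `trtexec --profilingVerbosity=detailed --dumpLayerInfo` gives the same per-layer dump.)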
I have searched through GitHub issues and the forum, but I haven’t found concrete documentation confirming whether ConvTranspose3d has INT8 kernel implementations on Jetson (Orin/Xavier) architectures.
Context on the layer:
- This specific part of the model performs a 16x upsample.
- I am currently using strided ConvTranspose3d layers to achieve this with learnable parameters.
My Questions:
- Is ConvTranspose3d supported in INT8 in TensorRT 10.x, or is the fallback to FP16 expected behavior due to missing kernels?
- Are there specific constraints (kernel size, stride, padding) required to trigger the INT8 kernel for 3D deconvolution?
- Architecture alternatives: since I need a high-quality 16x upsample, if ConvTranspose3d is not hardware-friendly for INT8, would you recommend a different approach in INT8 mode (e.g., Resize (nearest/trilinear) + standard Conv3d, or a 3D PixelShuffle)? A sketch of both alternatives follows this list.
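Here is a minimal sketch of the two alternatives I have in mind (my own placeholder code, not a validated INT8 recipe). PyTorch has no built-in 3D PixelShuffle, so the voxel-shuffle helper below is hand-rolled; `UpsampleConv3d` is a hypothetical drop-in replacement for one stride-2 deconv stage:

```python
import torch
import torch.nn as nn


class UpsampleConv3d(nn.Module):
    """2x nearest upsample + 3x3x3 conv, as a ConvTranspose3d replacement."""
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")  # exports as ONNX Resize
        self.conv = nn.Conv3d(in_c, out_c, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm3d(out_c)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(self.up(x))))


def pixel_shuffle_3d(x, r):
    """(N, C*r^3, D, H, W) -> (N, C//r^3 * r^3 -> C', D*r, H*r, W*r),
    the 3D analogue of nn.PixelShuffle (hand-rolled)."""
    n, c, d, h, w = x.shape
    oc = c // (r ** 3)
    x = x.view(n, oc, r, r, r, d, h, w)
    x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)  # interleave each r-factor with its spatial dim
    return x.reshape(n, oc, d * r, h * r, w * r)
```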
Here is the part of my model with the Deconv3D (ConvTranspose3d) layers:
```python
import torch.nn as nn
from torch.ao.nn.quantized import FloatFunctional


class Bottleneck3D(nn.Module):
    def __init__(self, c, expansion=2):
        super().__init__()
        mid_c = c * expansion
        self.conv1 = nn.Conv3d(c, mid_c, 1, bias=False)
        self.bn1 = nn.BatchNorm3d(mid_c)
        self.conv2 = nn.Conv3d(mid_c, mid_c, 3, padding=1, bias=False)  # groups=1
        self.bn2 = nn.BatchNorm3d(mid_c)
        self.conv3 = nn.Conv3d(mid_c, c, 1, bias=False)
        self.bn3 = nn.BatchNorm3d(c)
        self.act = nn.ReLU6(inplace=True)
        self.add = FloatFunctional()  # quantizable skip-connection add

    def forward(self, x):
        identity = x
        out = self.act(self.bn1(self.conv1(x)))
        out = self.act(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out = self.add.add(out, identity)
        out = self.act(out)
        return out


class Deconv3DStack(nn.Module):
    def __init__(self, in_c=64, deconv_layers=4, final_c=32):
        super().__init__()
        mid_channels = 256
        last_channels = 128
        # Each block doubles the spatial resolution via a stride-2 deconv,
        # so the four blocks together give the 16x upsample.
        self.block1 = nn.Sequential(
            nn.ConvTranspose3d(in_c, mid_channels * 2, kernel_size=2, stride=2, bias=False),
            nn.BatchNorm3d(mid_channels * 2),
            nn.ReLU(inplace=True),
            Bottleneck3D(mid_channels * 2),
            Bottleneck3D(mid_channels * 2),
        )
        self.block2 = nn.Sequential(
            nn.ConvTranspose3d(mid_channels * 2, mid_channels, kernel_size=2, stride=2, bias=False),
            nn.BatchNorm3d(mid_channels),
            nn.ReLU(inplace=True),
            Bottleneck3D(mid_channels),
        )
        self.block3 = nn.Sequential(
            nn.ConvTranspose3d(mid_channels, last_channels, kernel_size=2, stride=2, bias=False),
            nn.BatchNorm3d(last_channels),
            nn.ReLU(inplace=True),
        )
        self.block4 = nn.Sequential(
            nn.ConvTranspose3d(last_channels, final_c, kernel_size=2, stride=2, bias=False),
            nn.BatchNorm3d(final_c),
            nn.ReLU(inplace=True),
        )
        self.out_c = final_c

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        return x
```
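For completeness, a quick shape check (the input size is an arbitrary example) confirming that the four stride-2 blocks produce the 16x upsample:

```python
import torch

m = Deconv3DStack(in_c=64, final_c=32).eval()
with torch.no_grad():
    y = m(torch.randn(1, 64, 4, 4, 4))
print(y.shape)  # torch.Size([1, 32, 64, 64, 64]): 4 -> 64 in each spatial dim (16x)
```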
Any insights or documentation links would be greatly appreciated!
Thanks.