Description
Trtexec segfaults while building an engine from an ONNX model. The ONNX model is parsed fine (performing inference on this model with ONNX Runtime works), but during the optimization process (PointWiseFusion) trtexec segfaults. If I comment out modules 4 & 5 or modules 6 & 7 of MutualFeatureScreening, the engine compiles and runs correctly, so something about the 5->6 and 4->7 transitions is causing the problem. Each module can also be exported correctly on its own. I've tried different opsets, constant folding and "simplification" with onnxsim, but that doesn't really help.
[05/05/2025-15:57:51] [V] [TRT] PointWiseFusion: Fusing PWN(/6/Add) with PWN(/6/Sigmoid)
[05/05/2025-15:57:51] [V] [TRT] Running: PointWiseFusion on PWN(PWN(/6/Add), PWN(/6/Sigmoid))
[05/05/2025-15:57:51] [V] [TRT] PointWiseFusion: Fusing PWN(PWN(/6/Add), PWN(/6/Sigmoid)) with PWN(PWN(PWN(PWN(/5/Add), PWN(/5/Sigmoid)), PWN(/5/Mul)), PWN(/6/Mul))
Segmentation fault (core dumped)
Note that this is only a small part of the actual computer vision model I am trying to convert; I have isolated the problem to this module.
Environment
TensorRT Version: 10.0.1
GPU Type: NVIDIA RTX 5000 Ada Generation
Nvidia Driver Version: 570.86.10
CUDA Version: 12.8
CUDNN Version: 8.2.4.15
Operating System + Version: Ubuntu 22.04
Python Version (if applicable):
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 2.6.0+cu12.6
Baremetal or Container (if container which image + tag): baremetal
Relevant Files
import torch
from torch import nn, Tensor
from torch.nn import functional as F
def get_activation(act: str | None | nn.Module, inplace: bool = True):
"""get activation function"""
if act is None:
return nn.Identity()
if isinstance(act, nn.Module):
return act
act = act.lower()
m: nn.Module = {
"silu": nn.SiLU,
"swish": nn.SiLU,
"relu": nn.ReLU,
"leaky_relu": nn.LeakyReLU,
"gelu": nn.GELU,
"hardsigmoid": nn.Hardsigmoid,
}[act]()
if hasattr(m, "inplace"):
m.inplace = inplace
return m
class ConvNormLayer(nn.Module):
    """``nn.Conv2d`` followed by ``nn.BatchNorm2d`` and an optional activation.

    Args:
        ch_in: Number of input channels of the convolution.
        ch_out: Number of output channels (also the BatchNorm feature count).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding; when ``None`` it defaults to
            ``(kernel_size - 1) // 2``.
        bias: Whether the convolution has a bias term.
        act: ``None`` for no activation (``nn.Identity``), otherwise passed
            to ``get_activation``.
    """

    def __init__(
        self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None
    ):
        super().__init__()
        if padding is None:
            padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            ch_in, ch_out, kernel_size, stride, padding=padding, bias=bias
        )
        self.norm = nn.BatchNorm2d(ch_out)
        if act is None:
            self.act = nn.Identity()
        else:
            self.act = get_activation(act)

    def forward(self, x):
        """Apply convolution, then batch norm, then the activation."""
        x = self.conv(x)
        x = self.norm(x)
        return self.act(x)
class RepVggBlock(nn.Module):
    """Two parallel ``ConvNormLayer`` branches (3x3 and 1x1) summed, then activated.

    Args:
        ch_in: Input channels of both branches.
        ch_out: Output channels of both branches.
        act: ``None`` for no activation, otherwise passed to ``get_activation``.
    """

    def __init__(self, ch_in, ch_out, act="relu"):
        super().__init__()
        self.ch_in, self.ch_out = ch_in, ch_out
        # Both branches are built without their own activation; the block
        # activates once, after the sum.
        self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None)
        self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None)
        if act is None:
            self.act = nn.Identity()
        else:
            self.act = get_activation(act)

    def forward(self, x):
        """Return ``act(conv(x))`` if a ``conv`` attribute exists, else ``act(conv1(x) + conv2(x))``."""
        # Nothing in this class creates ``conv``; presumably it is attached by
        # an external re-parameterization step -- TODO confirm.
        if hasattr(self, "conv"):
            return self.act(self.conv(x))
        return self.act(self.conv1(x) + self.conv2(x))
class CSPRepLayer(nn.Module):
    """Two 1x1 ``ConvNormLayer`` paths, one followed by a stack of ``RepVggBlock``s, summed.

    Args:
        in_channels: Input channels of both 1x1 convolutions.
        out_channels: Output channels of the layer.
        num_blocks: Number of ``RepVggBlock``s in the stacked path.
        expansion: The hidden width is ``int(out_channels * expansion)``.
        bias: Whether the 1x1 convolutions have a bias term.
        act: Activation passed to every ``ConvNormLayer`` / ``RepVggBlock``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_blocks=3,
        expansion=1.0,
        bias=False,
        act="silu",
    ):
        super().__init__()
        hidden_ch = int(out_channels * expansion)
        self.conv1 = ConvNormLayer(in_channels, hidden_ch, 1, 1, bias=bias, act=act)
        self.conv2 = ConvNormLayer(in_channels, hidden_ch, 1, 1, bias=bias, act=act)
        blocks = [RepVggBlock(hidden_ch, hidden_ch, act=act) for _ in range(num_blocks)]
        self.bottlenecks = nn.Sequential(*blocks)
        # A final 1x1 projection is only needed when the hidden width differs
        # from the requested output width.
        if hidden_ch == out_channels:
            self.conv3 = nn.Identity()
        else:
            self.conv3 = ConvNormLayer(
                hidden_ch, out_channels, 1, 1, bias=bias, act=act
            )

    def forward(self, x):
        """Return ``conv3(bottlenecks(conv1(x)) + conv2(x))``."""
        stacked = self.bottlenecks(self.conv1(x))
        direct = self.conv2(x)
        return self.conv3(stacked + direct)
def channel_attention(x: Tensor):
    """Return ``sigmoid(avg_pool(x) + max_pool(x))`` with both pools reduced to 1x1.

    The adaptive average and adaptive max pools each collapse the last two
    dimensions of ``x`` to size ``(1, 1)``; their sum is passed through a
    sigmoid.
    """
    pooled = F.adaptive_avg_pool2d(x, (1, 1)) + F.adaptive_max_pool2d(x, (1, 1))
    return pooled.sigmoid()
def feature_selection(x: Tensor):
    """Scale ``x`` by the channel-attention weights computed from ``x`` itself."""
    weights = channel_attention(x)
    return weights * x
class FeatureSelection(nn.Module):
    """Parameter-free module wrapper around ``feature_selection``."""

    def forward(self, x: Tensor):
        """Return ``feature_selection(x)``."""
        selected = feature_selection(x)
        return selected
class ScreeningFeatureFusion(nn.Module):
    """Fuse ``f_in`` with a resampled ``f_weight`` gated by channel attention.

    Args:
        dim: Channel count of every convolution (input and output).
        upsample: ``True`` uses stride-2 ``nn.ConvTranspose2d`` layers (with
            ``output_padding=1``); ``False`` uses stride-2 ``nn.Conv2d`` layers.
        num_scale: Number of stride-2 layers. A value of 1 stores the layer
            directly; any other value stores an ``nn.Sequential`` of that
            many layers.
    """

    def __init__(self, dim: int, upsample: bool, num_scale: int = 1):
        super().__init__()
        layer_cls = nn.Conv2d
        conv_kwargs = dict(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            padding=1,
            stride=2,
            bias=True,
        )
        if upsample:
            layer_cls = nn.ConvTranspose2d
            conv_kwargs["output_padding"] = 1

        if num_scale == 1:
            self.conv = layer_cls(**conv_kwargs)
        else:
            stages = [layer_cls(**conv_kwargs) for _ in range(num_scale)]
            self.conv = nn.Sequential(*stages)

    def forward(self, f_in: Tensor, f_weight: Tensor):
        """Return ``f_in * channel_attention(r) + r`` where ``r = conv(f_weight)``."""
        resampled = self.conv(f_weight)
        gate = channel_attention(resampled)
        return f_in * gate + resampled
class MutualFeatureScreening(nn.Sequential):
    """Cross-scale fusion over three feature maps, stored as ten indexed sub-modules.

    Sub-module layout (indices as used in ``forward``):
        0, 1: upsampling ``ScreeningFeatureFusion`` (1 and 2 stages)
        2, 3: ``CSPRepLayer``
        4, 5: ``FeatureSelection``
        6, 7: downsampling ``ScreeningFeatureFusion`` (1 and 2 stages)
        8, 9: ``CSPRepLayer``

    Args:
        dim: Channel count used by every sub-module.
        expansion: Passed to each ``CSPRepLayer``.
        depth_scale: Each ``CSPRepLayer`` gets ``round(3 * depth_scale)`` blocks.
        act: Activation passed to each ``CSPRepLayer``.
    """

    def __init__(
        self,
        dim: int,
        expansion: float = 1.0,
        depth_scale: float = 1.0,
        act: str = "silu",
    ):
        num_blocks = round(3 * depth_scale)

        def csp():
            # Positional order: in_channels, out_channels, num_blocks,
            # expansion, bias, act.
            return CSPRepLayer(dim, dim, num_blocks, expansion, False, act)

        super().__init__(
            ScreeningFeatureFusion(dim, upsample=True),
            ScreeningFeatureFusion(dim, upsample=True, num_scale=2),
            csp(),
            csp(),
            FeatureSelection(),
            FeatureSelection(),
            ScreeningFeatureFusion(dim, upsample=False),
            ScreeningFeatureFusion(dim, upsample=False, num_scale=2),
            csp(),
            csp(),
        )

    def forward(self, feats: list[Tensor]):
        """Fuse three feature maps and return them as ``[s3, s4, s5]``.

        Args:
            feats: Exactly three tensors ``(s3, s4, s5)`` whose last dimension
                strictly decreases from ``s3`` to ``s5``.

        Returns:
            A list ``[s3, s4, s5]`` of the updated feature maps.
        """
        s3, s4, s5 = feats
        order_msg = (
            "Shapes are in wrong order, expected s3 as "
            "highest resolution and s5 as lowest"
        )
        assert s3.shape[-1] > s4.shape[-1] > s5.shape[-1], order_msg

        # Block 1: s5 is fused into s4 and s3 through the upsampling modules,
        # then each result goes through its own CSP layer.
        s4 = self[0](s4, s5)
        s3 = self[1](s3, s5)
        s4 = self[2](s4)
        s3 = self[3](s3)

        # Intermediate: feature selection on s5 and on the updated s4.
        s5 = self[4](s5)
        s4 = self[5](s4)

        # Block 2: the updated s3 is fused into s4 and s5 through the
        # downsampling modules, then each result goes through a CSP layer.
        s4 = self[6](s4, s3)
        s5 = self[7](s5, s3)
        s5 = self[8](s5)
        s4 = self[9](s4)

        return [s3, s4, s5]
# Module under test, plus one random 256-channel feature map per scale
# (80x80, 40x40 and 20x20), in the s3, s4, s5 order its forward() expects.
module = MutualFeatureScreening(256, expansion=0.5)
inputs = [torch.randn(1, 256, size, size) for size in (80, 40, 20)]

# Alternative single-module setups; swap one pair in for the two lines above.
# module = ScreeningFeatureFusion(256, upsample=False)
# inputs = tuple(torch.randn(1, 256, x, x) for x in [20, 40])
# module = ScreeningFeatureFusion(256, upsample=True)
# inputs = tuple(torch.randn(1, 256, x, x) for x in [40, 20])

# Write the traced model to "mwe.onnx".
torch.onnx.export(
    module,
    inputs,
    "mwe.onnx",
    opset_version=17,
    do_constant_folding=True,
)
Steps To Reproduce
Run the Python code above, then run trtexec on the generated ONNX file (mwe.onnx). If you comment out (self[4] AND self[5]) OR (self[6] AND self[7]), trtexec should work fine.