Extremely small (PyTorch) ONNX model takes >12 minutes to save an engine, and most minor changes/additions cause an OOM

Description

A very simple model (available in the linked repo; I’ll also include the model code at the bottom of this description) takes an extremely long time to build into a saved TensorRT engine. Most minor additions to it cause the build to run out of memory (OOM) after hours of waiting, despite 64 GB of memory being available.
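For a rough sense of how large the graph is, the operations in the model code at the bottom of this description can be counted by hand. This is only a back-of-the-envelope sketch: it assumes the usual tracing-based torch.onnx.export, which records the Python for loops unrolled, and the per-iteration count is my own tally from the source, not a measurement of the exported file. The real ONNX node count will differ (constants, shape ops, exporter optimizations).

```python
MAX_TRAV = 21  # same constant as in the model code

# Each loop runs once per odd d in 1, 3, ..., 2 * MAX_TRAV - 1.
iterations_per_loop = len(range(1, MAX_TRAV * 2 + 1, 2))

# Hand count per iteration: 2 slices, 2 adds, 2 cats, 2 minimums.
ops_per_iteration = 2 + 2 + 2 + 2

loops = 2  # one pass along columns, one along rows
unrolled_ops = loops * iterations_per_loop * ops_per_iteration

print(iterations_per_loop, unrolled_ops)  # 21 336
```

So although the source is short, the builder is handed a few hundred chained elementwise/concat operations rather than a handful.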

As an example of a change that causes an OOM, try changing torch.where(untraversable, 0, max_trav_sq) to torch.where(untraversable > 0, 0, max_trav_sq)

Environment

TensorRT Version: 8.6.1
GPU Type: NVIDIA GeForce RTX 3060
Nvidia Driver Version: 550.54.14
CUDA Version: 12.4
CUDNN Version: 8.9.7
Operating System + Version: Debian 11
Python Version (if applicable): 3.9.2
PyTorch Version (if applicable): 2.2.1+cu121

(I’ve also reproduced this issue using trtexec on a Jetson Nano with TensorRT 8.5, and in both the cuda:11.7.0-cudnn8-devel-ubuntu20.04 and cuda:11.8.0-cudnn8-devel-ubuntu20.04 Docker containers with TensorRT 8.6.1 and PyTorch 2.0.0 and 2.2.2, respectively.)

Relevant Files

Steps To Reproduce

Just run:
python convert.py

Note that to regenerate the ONNX file, you can run:
python trav.py save

Model Code

import sys
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.onnx


MAX_TRAV = 21


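class Traversability(nn.Module):
    def forward(self, untraversable: torch.Tensor) -> torch.Tensor:
        # Untraversable cells start at squared distance 0; every other cell
        # starts at the cap MAX_TRAV**2.
        max_trav_sq = torch.tensor(MAX_TRAV**2, dtype=torch.int32)
        sq_distances = torch.where(untraversable, 0, max_trav_sq)
        # Padding column used for out-of-bounds neighbours in the horizontal pass.
        extra_col = torch.full(
            (sq_distances.shape[0], sq_distances.shape[1], 1), max_trav_sq
        )
        # Horizontal pass: the odd increments 1, 3, 5, ... sum to k**2 after k steps.
        for d in range(1, MAX_TRAV * 2 + 1, 2):
            sq_distances = torch.minimum(
                sq_distances,
                torch.minimum(
                    torch.cat([sq_distances[:, :, 1:] + d, extra_col], 2),
                    torch.cat([extra_col, sq_distances[:, :, :-1] + d], 2),
                ),
            )
        # Padding row used for out-of-bounds neighbours in the vertical pass.
        extra_row = torch.full(
            (sq_distances.shape[0], 1, sq_distances.shape[2]), max_trav_sq
        )
        # Vertical pass: same scheme along dimension 1.
        for d in range(1, MAX_TRAV * 2 + 1, 2):
            sq_distances = torch.minimum(
                sq_distances,
                torch.minimum(
                    torch.cat([sq_distances[:, 1:, :] + d, extra_row], 1),
                    torch.cat([extra_row, sq_distances[:, :-1, :] + d], 1),
                ),
            )
        # Convert squared distances (capped at MAX_TRAV**2) to distances.
        return sq_distances.sqrt()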
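
For anyone trying to follow what the loops in forward compute, here is a minimal pure-Python sketch of a single horizontal pass over one row. It is not part of the repo; pass_1d and squared_distances are hypothetical names I made up for illustration. It mirrors the model's first loop: each step adds the next odd number d to both neighbours, and since 1 + 3 + ... + (2k - 1) = k**2, a cell k steps away from an untraversable cell ends up with squared distance k**2.

```python
MAX_TRAV = 21  # same constant as in the model code


def pass_1d(row, max_trav=MAX_TRAV):
    """Run the odd-increment minimum pass over one row of squared distances."""
    cap = max_trav ** 2
    for d in range(1, max_trav * 2 + 1, 2):
        # Out-of-bounds neighbours read as `cap`, like extra_col in the model
        # (the padding is not offset by d).
        right = [v + d for v in row[1:]] + [cap]
        left = [cap] + [v + d for v in row[:-1]]
        row = [min(c, l, r) for c, l, r in zip(row, left, right)]
    return row


def squared_distances(untraversable_row, max_trav=MAX_TRAV):
    """Start blocked cells at 0 and all others at the cap, then run the pass."""
    cap = max_trav ** 2
    start = [0 if blocked else cap for blocked in untraversable_row]
    return pass_1d(start, max_trav)


row = [False, False, False, True, False, False]
print(squared_distances(row))  # [9, 4, 1, 0, 1, 4]
```

The model applies the same pass a second time along the other axis and then takes the square root, so the result is a distance to the nearest untraversable cell, capped at MAX_TRAV.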