Description
A very simple model (available in the linked repo; the model code is also included at the bottom of this description) takes an extremely long time to build into a saved TensorRT engine. Most minor additions to it cause the build to run out of memory (OOM) after hours of waiting, despite 64 GB of memory being available.
As an example of a change that causes an OOM, try changing
torch.where(untraversable, 0, max_trav_sq)
to
torch.where(untraversable > 0, 0, max_trav_sq)
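To show how small this change is, here is a minimal sketch comparing the two expressions in eager PyTorch. It assumes the input is a boolean mask (torch.where requires a boolean condition); the 1x2x3 example tensor is made up for illustration and is not the real input shape.

```python
import torch

MAX_TRAV = 21
max_trav_sq = torch.tensor(MAX_TRAV**2, dtype=torch.int32)

# Hypothetical 1x2x3 boolean mask; the real input shape is not stated here.
untraversable = torch.tensor([[[True, False, False], [False, True, False]]])

# Original expression from the model.
original = torch.where(untraversable, 0, max_trav_sq)
# Modified expression that triggers the OOM during the engine build.
modified = torch.where(untraversable > 0, 0, max_trav_sq)

# Compare the two results element by element in eager mode.
print(torch.equal(original, modified))
```

On a boolean mask the extra comparison should be a no-op numerically, so the difference in build behavior appears to come from the exported graph rather than from the values computed.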
Environment
TensorRT Version: 8.6.1
GPU Type: NVIDIA GeForce RTX 3060
Nvidia Driver Version: 550.54.14
CUDA Version: 12.4
CUDNN Version: 8.9.7
Operating System + Version: Debian 11
Python Version (if applicable): 3.9.2
PyTorch Version (if applicable): 2.2.1+cu121
(I’ve also reproduced this issue using trtexec on a Jetson Nano with TensorRT 8.5, and in both the cuda:11.7.0-cudnn8-devel-ubuntu20.04 and cuda:11.8.0-cudnn8-devel-ubuntu20.04 Docker containers with TensorRT 8.6.1 and PyTorch 2.0.0 and 2.2.2 respectively.)
Relevant Files
Steps To Reproduce
Just run:
python convert.py
Note that to regenerate the ONNX file, you can run:
python trav.py save
Model Code
import sys

from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.onnx

MAX_TRAV = 21


class Traversability(nn.Module):
    def forward(self, untraversable: torch.Tensor) -> torch.Tensor:
        # Untraversable cells start at 0; all other cells start at the cap.
        max_trav_sq = torch.tensor(MAX_TRAV**2, dtype=torch.int32)
        sq_distances = torch.where(untraversable, 0, max_trav_sq)

        # Horizontal pass (dim 2). d takes the odd values 1, 3, 5, ...,
        # whose running sum after k steps is k**2.
        extra_col = torch.full(
            (sq_distances.shape[0], sq_distances.shape[1], 1), max_trav_sq
        )
        for d in range(1, MAX_TRAV * 2 + 1, 2):
            sq_distances = torch.minimum(
                sq_distances,
                torch.minimum(
                    torch.cat([sq_distances[:, :, 1:] + d, extra_col], 2),
                    torch.cat([extra_col, sq_distances[:, :, :-1] + d], 2),
                ),
            )

        # Vertical pass (dim 1), same scheme applied along the rows.
        extra_row = torch.full(
            (sq_distances.shape[0], 1, sq_distances.shape[2]), max_trav_sq
        )
        for d in range(1, MAX_TRAV * 2 + 1, 2):
            sq_distances = torch.minimum(
                sq_distances,
                torch.minimum(
                    torch.cat([sq_distances[:, 1:, :] + d, extra_row], 1),
                    torch.cat([extra_row, sq_distances[:, :-1, :] + d], 1),
                ),
            )
        return sq_distances.sqrt()
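For anyone reading the loops above for the first time: each pass adds the odd numbers 1, 3, 5, ..., and since 1 + 3 + ... + (2k - 1) = k², a cell k steps away from an untraversable cell ends up with squared distance k², capped at MAX_TRAV**2. Below is a minimal pure-Python sketch of a single horizontal pass on one row, meant only to illustrate the idea. The helper name row_pass is hypothetical and is not part of the repo.

```python
MAX_TRAV = 21  # same cap as in the model code


def row_pass(untraversable_row):
    """Pure-Python sketch of one horizontal pass of the model on a single row."""
    cap = MAX_TRAV ** 2
    # Untraversable cells start at 0, everything else at the cap.
    sq = [0 if u else cap for u in untraversable_row]
    for d in range(1, MAX_TRAV * 2 + 1, 2):
        # Shifted copies, padded with the cap (like extra_col in the model).
        from_right = [v + d for v in sq[1:]] + [cap]
        from_left = [cap] + [v + d for v in sq[:-1]]
        sq = [min(a, b, c) for a, b, c in zip(sq, from_right, from_left)]
    return sq


row = [False, False, False, True, False, False]
print(row_pass(row))  # [9, 4, 1, 0, 1, 4]
```

The full model unrolls this into MAX_TRAV slice, concat, add, and minimum steps per direction, so the exported graph contains a long chain of small ops even though the Python source is short.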