[TensorRT] OutOfMemory Error when building engine from ONNX model

I’m currently attempting to convert an ONNX model exported from this PyTorch I3D model. I exported it with PyTorch 1.2.0, which appeared to succeed.
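For reference, the export call looked roughly like this (a sketch: model stands in for the I3D network built from the linked repo, and the dummy input shape matches what I use for inference later):

import torch

# `model` is the I3D network from the linked repository (placeholder here)
model.eval()
dummy = torch.randn(1, 2, 32, 144, 144)
torch.onnx.export(
    model, dummy, 'model.onnx',
    opset_version=10,  # matches the opset trtexec reports further down
    input_names=['x'],
    output_names=['y_cls', 'y_loc', 'y_cls_track', 'y_loc_track', 'y_cmp_track'],
)

However, when I use TensorRT 7.0.0.11 to build a CUDA engine for accelerated inference, I receive the following error: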

[TensorRT] ERROR: Internal error: could not find any implementation for node (Unnamed Layer* 11) [Convolution] + (Unnamed Layer* 13) [Activation] || (Unnamed Layer* 17) [Convolution] + (Unnamed Layer* 19) [Activation], try increasing the workspace size with IBuilder::setMaxWorkspaceSize()
[TensorRT] ERROR: ../builder/tacticOptimizer.cpp (1523) - OutOfMemory Error in computeCosts: 0

The following is the Python 3.7 code I’m using to build the engine. Note that common is the common.py file from the TensorRT 7.0.0.11 samples/python directory.

import numpy as np
import common
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)


def main():
    print('TensorRT Version:', trt.__version__)

    onnx_filename = 'model.onnx'

    def build_engine_onnx(model_file):
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network(common.EXPLICIT_BATCH) as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser:
            builder.max_workspace_size = int((1 << 32) * 1.61)  # ~6.9 GB
            builder.max_batch_size = 1
            with open(model_file, 'rb') as model:
                if not parser.parse(model.read()):
                    # surface any parser errors instead of silently continuing
                    for i in range(parser.num_errors):
                        print(parser.get_error(i))
                    return None
            return builder.build_cuda_engine(network)

    with build_engine_onnx(onnx_filename) as engine:
        # failure occurs before reaching this point
        pass

if __name__ == "__main__":
    main()

I have set builder.max_workspace_size to the largest value I can for my GPU (2060 SUPER, 8 GB). Watching nvidia-smi, I can see my GPU memory max out for a few seconds before finally dropping back to zero when the engine build fails.

This model fits comfortably on my GPU when running PyTorch inference with a batch size of 5+, so it seems very strange that I can’t use it with TensorRT at even a batch size of 1.

Is there something improper in my code that is causing TensorRT to use excess memory? Are there any workarounds or settings I can use to reduce GPU memory usage? I’ve also tried setting builder.fp16_mode = True, which seems to let engine creation proceed further, but it still only gets through about 20% of the layers in the model before running out of memory.
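For completeness, the FP16 attempt was just the build function above with one extra flag (the TRT 7 builder attribute):

builder.max_workspace_size = int((1 << 32) * 1.61)
builder.max_batch_size = 1
builder.fp16_mode = True  # gets further through the network, but still runs out of memory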

Hi,
Can you try a few things:

  1. Check the ONNX model using the checker function and see if it passes:
    import onnx

    model = onnx.load("model.onnx")

    onnx.checker.check_model(model)

  2. If step 1 passes, try running the ONNX model and check the memory consumption

  3. Please try the trtexec command to generate a TRT model:
    https://github.com/NVIDIA/TensorRT/blob/master/samples/opensource/trtexec/README.md

If the issue persists, could you please share the ONNX model so we can help further.

Thanks

Thank you for the response Sunil,

The model.onnx file can be downloaded here: https://obj.umiacs.umd.edu/trt-example/model.onnx

I tried what you described. The model passes the ONNX checker. Afterwards, I was able to run it successfully with ONNX Runtime as follows:

import onnx
import onnxruntime as rt
import numpy as np

def main():
    print('ONNX version:', onnx.__version__)
    print('ONNX-runtime version:', rt.__version__)

    onnx_filename = 'model.onnx'

    sess = rt.InferenceSession(onnx_filename)
    input_names = [i.name for i in sess.get_inputs()]
    output_names = [o.name for o in sess.get_outputs()]
    print('Input names:', input_names)
    print('Output names:', output_names)

    np.random.seed(1234)
    rand_samp = np.random.randn(1, 2, 32, 144, 144).astype(np.float32)

    preds = sess.run(output_names, {input_names[0]: rand_samp})

    print(preds[0])

if __name__ == "__main__":
    main()

The code above uses less than 2 GB of RAM and prints the following, which matches my output from PyTorch.

ONNX version: 1.6.0
ONNX-runtime version: 1.2.0
Input names: ['x']
Output names: ['y_cls', 'y_loc', 'y_cls_track', 'y_loc_track', 'y_cmp_track']
[[ 0.03988145 -0.09616947 -0.8033749  -1.2462692  -0.45356765 -1.4268675
  -0.65312874 -0.14997464  0.33097583 -0.06500828  0.15914007 -0.51482236
   0.1772879   0.51031244 -0.03300643  0.21024594 -0.20662525 -0.21396095
  -0.73346174 -0.13845211 -0.21005228  0.03861606  0.56694525  0.54108495
   0.12292077 -0.57820493  0.33151776 -0.54985714  0.74340343  0.6315851
   1.0283502  -0.64480615  0.19452536 -0.04577645 -0.11097859 -0.7217783
   0.73407394  0.8563524 ]]
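(For the record, the PyTorch side of that comparison was roughly the sketch below, where I3D and 'model.pth' are hypothetical stand-ins for the model class and checkpoint from the linked repo:)

import numpy as np
import torch

np.random.seed(1234)
x = torch.from_numpy(np.random.randn(1, 2, 32, 144, 144).astype(np.float32))

model = I3D()  # hypothetical: model class from the linked repository
model.load_state_dict(torch.load('model.pth'))  # hypothetical checkpoint path
model.eval()
with torch.no_grad():
    preds = model(x)  # assumed to return the five outputs in the same order
print(preds[0].numpy())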

Based on your suggestion, I tried converting the model to TensorRT using the trtexec tool

$ trtexec --onnx=model.onnx --shapes=input:1x2x32x144x144 --maxBatch=1 --workspace=6750 --verbose

which results in

&&&& RUNNING TensorRT.trtexec # trtexec --onnx=model.onnx --shapes=input:1x2x32x144x144 --maxBatch=1 --workspace=6750
[05/14/2020-12:35:56] [I] === Model Options ===
[05/14/2020-12:35:56] [I] Format: ONNX
[05/14/2020-12:35:56] [I] Model: model.onnx
[05/14/2020-12:35:56] [I] Output:
[05/14/2020-12:35:56] [I] === Build Options ===
[05/14/2020-12:35:56] [I] Max batch: explicit
[05/14/2020-12:35:56] [I] Workspace: 6750 MB
[05/14/2020-12:35:56] [I] minTiming: 1
[05/14/2020-12:35:56] [I] avgTiming: 8
[05/14/2020-12:35:56] [I] Precision: FP32
[05/14/2020-12:35:56] [I] Calibration: 
[05/14/2020-12:35:56] [I] Safe mode: Disabled
[05/14/2020-12:35:56] [I] Save engine: 
[05/14/2020-12:35:56] [I] Load engine: 
[05/14/2020-12:35:56] [I] Inputs format: fp32:CHW
[05/14/2020-12:35:56] [I] Outputs format: fp32:CHW
[05/14/2020-12:35:56] [I] Input build shape: input=1x2x32x144x144+1x2x32x144x144+1x2x32x144x144
[05/14/2020-12:35:56] [I] === System Options ===
[05/14/2020-12:35:56] [I] Device: 0
[05/14/2020-12:35:56] [I] DLACore: 
[05/14/2020-12:35:56] [I] Plugins:
[05/14/2020-12:35:56] [I] === Inference Options ===
[05/14/2020-12:35:56] [I] Batch: Explicit
[05/14/2020-12:35:56] [I] Iterations: 10
[05/14/2020-12:35:56] [I] Duration: 3s (+ 200ms warm up)
[05/14/2020-12:35:56] [I] Sleep time: 0ms
[05/14/2020-12:35:56] [I] Streams: 1
[05/14/2020-12:35:56] [I] ExposeDMA: Disabled
[05/14/2020-12:35:56] [I] Spin-wait: Disabled
[05/14/2020-12:35:56] [I] Multithreading: Disabled
[05/14/2020-12:35:56] [I] CUDA Graph: Disabled
[05/14/2020-12:35:56] [I] Skip inference: Disabled
[05/14/2020-12:35:56] [I] Inputs:
[05/14/2020-12:35:56] [I] === Reporting Options ===
[05/14/2020-12:35:56] [I] Verbose: Disabled
[05/14/2020-12:35:56] [I] Averages: 10 inferences
[05/14/2020-12:35:56] [I] Percentile: 99
[05/14/2020-12:35:56] [I] Dump output: Disabled
[05/14/2020-12:35:56] [I] Profile: Disabled
[05/14/2020-12:35:56] [I] Export timing to JSON file: 
[05/14/2020-12:35:56] [I] Export output to JSON file: 
[05/14/2020-12:35:56] [I] Export profile to JSON file: 
[05/14/2020-12:35:56] [I] 
----------------------------------------------------------------
Input filename:   model.onnx
ONNX IR version:  0.0.4
Opset version:    10
Producer name:    pytorch
Producer version: 1.2
Domain:           
Model version:    0
Doc string:       
----------------------------------------------------------------
[05/14/2020-12:35:56] [W] [TRT] Setting layouts of network and plugin input/output tensors to linear, as 3D operators are found and 3D non-linear IO formats are not supported, yet.
[05/14/2020-12:35:58] [E] [TRT] Internal error: could not find any implementation for node (Unnamed Layer* 14) [Convolution] + (Unnamed Layer* 16) [Activation] || (Unnamed Layer* 20) [Convolution] + (Unnamed Layer* 22) [Activation], try increasing the workspace size with IBuilder::setMaxWorkspaceSize()
[05/14/2020-12:35:58] [E] [TRT] ../builder/tacticOptimizer.cpp (1523) - OutOfMemory Error in computeCosts: 0
[05/14/2020-12:35:58] [E] Engine creation failed
[05/14/2020-12:35:58] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec # trtexec --onnx=model.onnx --shapes=input:1x2x32x144x144 --maxBatch=1 --workspace=6750

If you have a chance to take a look at what could be causing this issue I would be very grateful!

The issue seems to be due to limited 3D support in the current TRT version. A future release may have better 3D support, so please stay tuned for updates on the TRT website.
For now, you can try creating a custom plugin for the unsupported 3D layers.

Thanks

I was able to track down the cause of this error. It turns out that TensorRT doesn’t like it when there are two or more parallel branches that each contain two or more 3D convolution layers.

For example, the following PyTorch module exports fine to ONNX but fails to build an engine in TensorRT 7.0.0.11 with the same Internal error: could not find any implementation for node and ../builder/tacticOptimizer.cpp (1523) - OutOfMemory Error in computeCosts: 0 as above:

import torch
import torch.nn as nn


class Broken(nn.Module):
    def __init__(self):
        super().__init__()
        self.branch1 = nn.Sequential(
            nn.Conv3d(1, 1, kernel_size=1, stride=1, bias=False),
            nn.Conv3d(1, 1, kernel_size=1, stride=1, bias=False),
        )

        self.branch2 = nn.Sequential(
            nn.Conv3d(1, 1, kernel_size=1, stride=1, bias=False),
            nn.Conv3d(1, 1, kernel_size=1, stride=1, bias=False),
        )

    def forward(self, x):
        y1 = self.branch1(x)
        y2 = self.branch2(x)
        return torch.cat((y1, y2), dim=1)
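To reproduce, I exported Broken with a standard torch.onnx.export call along these lines (the shape is arbitrary; any single-channel 5D input shows the problem) and fed the result to the same build code and trtexec command as above:

dummy = torch.randn(1, 1, 16, 32, 32)
torch.onnx.export(Broken().eval(), dummy, 'broken.onnx', opset_version=10)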

However, the following module works just fine (though the output of this one isn’t equivalent to Broken):

class Working(nn.Module):
    def __init__(self):
        super().__init__()
        self.branch1 = nn.Conv3d(1, 1, kernel_size=1, stride=1, bias=False)
        self.branch2 = nn.Conv3d(1, 1, kernel_size=1, stride=1, bias=False)
        self.branch3 = nn.Sequential(
            nn.Conv3d(1, 1, kernel_size=1, stride=1, bias=False),
            nn.Conv3d(1, 1, kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        y1 = self.branch1(x)
        y2 = self.branch2(x)
        y3 = self.branch3(x)
        return torch.cat((y1, y2, y3), dim=1)

Once I figured out what was causing the error, I came up with the following pattern, which produces a module equivalent to Broken but works with TensorRT:

class BrokenFix(nn.Module):
    def __init__(self):
        super().__init__()
        # branch1
        self.c1 = nn.Conv3d(1, 1, kernel_size=1, stride=1, bias=False)
        self.c2 = nn.Conv3d(1, 1, kernel_size=1, stride=1, bias=False)
        # branch2
        self.c3 = nn.Conv3d(1, 1, kernel_size=1, stride=1, bias=False)
        self.c4 = nn.Conv3d(1, 1, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        x1 = self.c1(x)
        x2 = self.c3(x)
        # the next 3 lines are mathematically equivalent to a no-op but force TRT to join
        # branches and split again which avoids an error
        x12 = torch.cat((x1, x2), dim=1)
        x1 = x12[:, :1, :, :, :]
        x2 = x12[:, 1:, :, :, :]
        y1 = self.c2(x1)
        y2 = self.c4(x2)
        return torch.cat((y1, y2), dim=1)

Basically, I force branch1 and branch2 to merge (via torch.cat) and then split again using slicing. I applied this pattern to all the Mixed* layers in the I3D network, and I have now been able to successfully export my model to TensorRT.
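As a sanity check, here is a minimal sketch (the state-dict key mapping is mine) verifying that BrokenFix reproduces Broken’s output once the weights are copied across:

broken, fixed = Broken().eval(), BrokenFix().eval()

# Map Broken's Sequential conv weights onto BrokenFix's flat conv layers.
mapping = {
    'branch1.0.weight': 'c1.weight', 'branch1.1.weight': 'c2.weight',
    'branch2.0.weight': 'c3.weight', 'branch2.1.weight': 'c4.weight',
}
fixed.load_state_dict({mapping[k]: v for k, v in broken.state_dict().items()})

x = torch.randn(1, 1, 8, 16, 16)
with torch.no_grad():
    assert torch.allclose(broken(x), fixed(x))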

I don’t see how this is simply a failure to support certain operations. This seems like a bug in TensorRT to me.

The issue might be due to limited 3D support in TRT 7.
For now, you can try creating a plugin with multiple outputs and doing the parallel branches inside the plugin.

Thanks