CUDA OutOfMemory when creating a tensor with 2^29 (~0.5 G) elements

Description

When an ONNX graph contains a tensor with 2^29 elements, running trtexec --onnx=output.onnx --verbose fails with the error message "[02/09/2022-22:02:48] [TRT] [W] Requested amount of GPU memory (18446744069414584320 bytes) could not be allocated. There may not be enough free memory for allocation to succeed." However, there is actually enough GPU memory (>=7GB) to create such a tensor.
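A quick arithmetic check of the "enough memory" claim, as a minimal Python sketch. The element count comes from the repro, the device size from the "Device Global Memory" line of the trtexec log, and it assumes 1 byte per bool element and 4 bytes per float32 element. It also assumes, conservatively, separate input and output buffers for the single Identity layer.

```python
# Sketch: dense byte footprint of the 2**29-element repro tensor.
# Assumes 1 byte per bool element and 4 bytes per float32 element.
NUM_ELEMENTS = 2**29
DEVICE_MIB = 7982  # "Device Global Memory" reported by trtexec
BYTES_PER_ELEMENT = {"bool": 1, "float32": 4}


def footprint_mib(num_elements: int, dtype: str) -> float:
    """Return the size of a dense tensor in MiB."""
    return num_elements * BYTES_PER_ELEMENT[dtype] / 2**20


for dtype in BYTES_PER_ELEMENT:
    one_tensor = footprint_mib(NUM_ELEMENTS, dtype)
    # Assumption: separate input and output buffers for the Identity layer.
    in_plus_out = 2 * one_tensor
    print(f"{dtype}: {one_tensor:.0f} MiB per tensor, "
          f"{in_plus_out:.0f} MiB in+out, fits: {in_plus_out < DEVICE_MIB}")
```

Even the float32 variant stays well under the reported device memory, while the requested 18446744069414584320 bytes is about 16 EiB, so the request itself is not a plausible tensor size.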

A hypothesis

I dug a bit, and it looks like the problem is caused by an overflow when the tensor size is multiplied by 8 in int32 and the result is then cast to int64. For example, the huge number 18446744069414584320 = 0xFFFFFFFF00000000 would result from an overflow when computing 2^29 * 8 in int32 followed by a cast to int64. When I change 2^29 to 2^29 + 2^3, the number also changes, to 18446744069414584384 = 0xFFFFFFFF00000040, which supports the hypothesis.
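The arithmetic behind the hypothesis can be checked directly. The minimal sketch below uses two hypothetical models of the size computation; they are not TensorRT's code, only integer arithmetic to compare against the logged values. It suggests a refinement: a product of n * 8 computed entirely in int32 wraps 2^32 to 0 and would not set the high 32 bits, whereas computing n * 4 in int32 (overflowing at 2^31), sign-extending to int64, and then doubling reproduces both logged values exactly. Either way, the numbers are consistent with an int32 overflow followed by a widening cast.

```python
# Hypothetical arithmetic models of the size overflow (not TensorRT code).
INT32_MIN = -(2**31)


def wrap_int32(x: int) -> int:
    """Reduce x to a signed 32-bit value using two's-complement wrap-around
    (formally undefined for signed overflow in C++, but what usually happens)."""
    return (x - INT32_MIN) % 2**32 + INT32_MIN


def as_uint64(x: int) -> int:
    """Reinterpret a signed 64-bit value as unsigned, as the log prints it."""
    return x % 2**64


def model_times8_int32(n: int) -> int:
    # n * 8 computed in int32, then sign-extended to int64.
    return as_uint64(wrap_int32(n * 8))


def model_times4_int32_then_double(n: int) -> int:
    # n * 4 computed in int32, sign-extended to int64, then doubled in int64.
    return as_uint64(wrap_int32(n * 4) * 2)


for n in (2**29, 2**29 + 2**3):
    print(n, hex(model_times8_int32(n)), hex(model_times4_int32_then_double(n)))
```

For n = 2^29 the second model yields 0xffffffff00000000 and for n = 2^29 + 8 it yields 0xffffffff00000040, matching the two requests in the logs; the first model yields 0x0 and 0x40.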

An example model (a single Identity node) can be generated with the following code:

import numpy as np
import onnx
import onnx.checker
import torch


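# One input with 2**29 elements. Note that np.random.rand first builds a
# float64 temporary of 2**29 * 8 bytes (4 GiB) in host memory.
inputs = {
    'a': np.random.rand(2**29) > 0.5,  # bool tensor, 512 MiB
    # 'a': np.random.rand(2**29).astype(np.float32),  # same error with this line
}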


class Model(torch.nn.Module):
    @torch.no_grad()
    def forward(self, a):
        return (a,)


model = Model()
model.eval()
torch_inp = {k: torch.from_numpy(v) for k, v in inputs.items()}
# Export a graph containing a single Identity node to output.onnx.
torch.onnx.export(
    model, tuple(torch_inp.values()),
    "output.onnx", input_names=list(inputs.keys()),
    verbose=False, opset_version=14)

# Validate the exported model.
onnx_model = onnx.load("output.onnx")
onnx.checker.check_model(onnx_model, full_check=True)

Environment

TensorRT Version: v8201
GPU Type: RTX 2080
Nvidia Driver Version: 495.29.05
CUDA Version: 11.5
CUDNN Version: 8.3.0
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 3.8.10
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.10.2+cpu
Baremetal or Container (if container which image + tag): Container based on nvcr.io/nvidia/tensorrt:21.11-py3

Relevant Files

output.onnx.zip (287 Bytes)

Steps To Reproduce

trtexec --onnx=output.onnx --verbose

The output:

&&&& RUNNING TensorRT.trtexec [TensorRT v8201] # trtexec --onnx=output.onnx --verbose
[02/09/2022-23:34:46] [I] === Model Options ===
[02/09/2022-23:34:46] [I] Format: ONNX
[02/09/2022-23:34:46] [I] Model: output.onnx
[02/09/2022-23:34:46] [I] Output:
[02/09/2022-23:34:46] [I] === Build Options ===
[02/09/2022-23:34:46] [I] Max batch: explicit batch
[02/09/2022-23:34:46] [I] Workspace: 16 MiB
[02/09/2022-23:34:46] [I] minTiming: 1
[02/09/2022-23:34:46] [I] avgTiming: 8
[02/09/2022-23:34:46] [I] Precision: FP32
[02/09/2022-23:34:46] [I] Calibration: 
[02/09/2022-23:34:46] [I] Refit: Disabled
[02/09/2022-23:34:46] [I] Sparsity: Disabled
[02/09/2022-23:34:46] [I] Safe mode: Disabled
[02/09/2022-23:34:46] [I] DirectIO mode: Disabled
[02/09/2022-23:34:46] [I] Restricted mode: Disabled
[02/09/2022-23:34:46] [I] Save engine: 
[02/09/2022-23:34:46] [I] Load engine: 
[02/09/2022-23:34:46] [I] Profiling verbosity: 0
[02/09/2022-23:34:46] [I] Tactic sources: Using default tactic sources
[02/09/2022-23:34:46] [I] timingCacheMode: local
[02/09/2022-23:34:46] [I] timingCacheFile: 
[02/09/2022-23:34:46] [I] Input(s)s format: fp32:CHW
[02/09/2022-23:34:46] [I] Output(s)s format: fp32:CHW
[02/09/2022-23:34:46] [I] Input build shapes: model
[02/09/2022-23:34:46] [I] Input calibration shapes: model
[02/09/2022-23:34:46] [I] === System Options ===
[02/09/2022-23:34:46] [I] Device: 0
[02/09/2022-23:34:46] [I] DLACore: 
[02/09/2022-23:34:46] [I] Plugins:
[02/09/2022-23:34:46] [I] === Inference Options ===
[02/09/2022-23:34:46] [I] Batch: Explicit
[02/09/2022-23:34:46] [I] Input inference shapes: model
[02/09/2022-23:34:46] [I] Iterations: 10
[02/09/2022-23:34:46] [I] Duration: 3s (+ 200ms warm up)
[02/09/2022-23:34:46] [I] Sleep time: 0ms
[02/09/2022-23:34:46] [I] Idle time: 0ms
[02/09/2022-23:34:46] [I] Streams: 1
[02/09/2022-23:34:46] [I] ExposeDMA: Disabled
[02/09/2022-23:34:46] [I] Data transfers: Enabled
[02/09/2022-23:34:46] [I] Spin-wait: Disabled
[02/09/2022-23:34:46] [I] Multithreading: Disabled
[02/09/2022-23:34:46] [I] CUDA Graph: Disabled
[02/09/2022-23:34:46] [I] Separate profiling: Disabled
[02/09/2022-23:34:46] [I] Time Deserialize: Disabled
[02/09/2022-23:34:46] [I] Time Refit: Disabled
[02/09/2022-23:34:46] [I] Skip inference: Disabled
[02/09/2022-23:34:46] [I] Inputs:
[02/09/2022-23:34:46] [I] === Reporting Options ===
[02/09/2022-23:34:46] [I] Verbose: Enabled
[02/09/2022-23:34:46] [I] Averages: 10 inferences
[02/09/2022-23:34:46] [I] Percentile: 99
[02/09/2022-23:34:46] [I] Dump refittable layers:Disabled
[02/09/2022-23:34:46] [I] Dump output: Disabled
[02/09/2022-23:34:46] [I] Profile: Disabled
[02/09/2022-23:34:46] [I] Export timing to JSON file: 
[02/09/2022-23:34:46] [I] Export output to JSON file: 
[02/09/2022-23:34:46] [I] Export profile to JSON file: 
[02/09/2022-23:34:46] [I] 
[02/09/2022-23:34:46] [I] === Device Information ===
[02/09/2022-23:34:46] [I] Selected Device: NVIDIA GeForce RTX 2080
[02/09/2022-23:34:46] [I] Compute Capability: 7.5
[02/09/2022-23:34:46] [I] SMs: 46
[02/09/2022-23:34:46] [I] Compute Clock Rate: 1.71 GHz
[02/09/2022-23:34:46] [I] Device Global Memory: 7982 MiB
[02/09/2022-23:34:46] [I] Shared Memory per SM: 64 KiB
[02/09/2022-23:34:46] [I] Memory Bus Width: 256 bits (ECC disabled)
[02/09/2022-23:34:46] [I] Memory Clock Rate: 7 GHz
[02/09/2022-23:34:46] [I] 
[02/09/2022-23:34:46] [I] TensorRT version: 8.2.1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::GridAnchor_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::GridAnchorRect_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::NMS_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::Reorg_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::Region_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::Clip_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::LReLU_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::PriorBox_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::Normalize_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::ScatterND version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::RPROI_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::BatchedNMS_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::BatchedNMSDynamic_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::FlattenConcat_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::CropAndResize version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::DetectionLayer_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::EfficientNMS_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::EfficientNMS_ONNX_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::EfficientNMS_TFTRT_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::Proposal version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::ProposalLayer_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::PyramidROIAlign_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::ResizeNearest_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::Split version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::SpecialSlice_TRT version 1
[02/09/2022-23:34:46] [V] [TRT] Registered plugin creator - ::InstanceNormalization_TRT version 1
[02/09/2022-23:34:47] [I] [TRT] [MemUsageChange] Init CUDA: CPU +321, GPU +0, now: CPU 333, GPU 265 (MiB)
[02/09/2022-23:34:47] [I] [TRT] [MemUsageSnapshot] Begin constructing builder kernel library: CPU 333 MiB, GPU 265 MiB
[02/09/2022-23:34:47] [I] [TRT] [MemUsageSnapshot] End constructing builder kernel library: CPU 468 MiB, GPU 299 MiB
[02/09/2022-23:34:47] [I] Start parsing network model
[02/09/2022-23:34:47] [I] [TRT] ----------------------------------------------------------------
[02/09/2022-23:34:47] [I] [TRT] Input filename:   output.onnx
[02/09/2022-23:34:47] [I] [TRT] ONNX IR version:  0.0.7
[02/09/2022-23:34:47] [I] [TRT] Opset version:    14
[02/09/2022-23:34:47] [I] [TRT] Producer name:    pytorch
[02/09/2022-23:34:47] [I] [TRT] Producer version: 1.10
[02/09/2022-23:34:47] [I] [TRT] Domain:           
[02/09/2022-23:34:47] [I] [TRT] Model version:    0
[02/09/2022-23:34:47] [I] [TRT] Doc string:       
[02/09/2022-23:34:47] [I] [TRT] ----------------------------------------------------------------
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::GridAnchor_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::GridAnchorRect_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::NMS_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::Reorg_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::Region_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::Clip_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::LReLU_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::PriorBox_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::Normalize_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::ScatterND version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::RPROI_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::BatchedNMS_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::BatchedNMSDynamic_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::FlattenConcat_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::CropAndResize version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::DetectionLayer_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::EfficientNMS_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::EfficientNMS_ONNX_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::EfficientNMS_TFTRT_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::Proposal version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::ProposalLayer_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::PyramidROIAlign_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::ResizeNearest_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::Split version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::SpecialSlice_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Plugin creator already registered - ::InstanceNormalization_TRT version 1
[02/09/2022-23:34:47] [V] [TRT] Adding network input: a with dtype: bool, dimensions: (536870912)
[02/09/2022-23:34:47] [V] [TRT] Registering tensor: a for ONNX tensor: a
[02/09/2022-23:34:47] [V] [TRT] Parsing node: Identity_0 [Identity]
[02/09/2022-23:34:47] [V] [TRT] Searching for input: a
[02/09/2022-23:34:47] [V] [TRT] Identity_0 [Identity] inputs: [a -> (536870912)[BOOL]], 
[02/09/2022-23:34:47] [V] [TRT] Registering layer: Identity_0 for ONNX node: Identity_0
[02/09/2022-23:34:47] [V] [TRT] Registering tensor: 1_0 for ONNX tensor: 1
[02/09/2022-23:34:47] [V] [TRT] Identity_0 [Identity] outputs: [1 -> (536870912)[BOOL]], 
[02/09/2022-23:34:47] [V] [TRT] Marking 1_0 as output: 1
[02/09/2022-23:34:47] [I] Finish parsing network model
[02/09/2022-23:34:47] [V] [TRT] Applying generic optimizations to the graph for inference.
[02/09/2022-23:34:47] [V] [TRT] Original: 1 layers
[02/09/2022-23:34:47] [V] [TRT] After dead-layer removal: 1 layers
[02/09/2022-23:34:47] [V] [TRT] After Myelin optimization: 1 layers
[02/09/2022-23:34:47] [V] [TRT] Applying ScaleNodes fusions.
[02/09/2022-23:34:47] [V] [TRT] After scale fusion: 1 layers
[02/09/2022-23:34:47] [V] [TRT] After vertical fusions: 1 layers
[02/09/2022-23:34:47] [V] [TRT] After dupe layer removal: 1 layers
[02/09/2022-23:34:47] [V] [TRT] After final dead-layer removal: 1 layers
[02/09/2022-23:34:47] [V] [TRT] After tensor merging: 1 layers
[02/09/2022-23:34:47] [V] [TRT] After concat removal: 1 layers
[02/09/2022-23:34:47] [V] [TRT] Graph construction and optimization completed in 0.00082093 seconds.
[02/09/2022-23:34:47] [V] [TRT] Using cublasLt as a tactic source
[02/09/2022-23:34:47] [I] [TRT] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +508, GPU +222, now: CPU 976, GPU 521 (MiB)
[02/09/2022-23:34:47] [V] [TRT] Using cuDNN as a tactic source
[02/09/2022-23:34:47] [I] [TRT] [MemUsageChange] Init cuDNN: CPU +113, GPU +52, now: CPU 1089, GPU 573 (MiB)
[02/09/2022-23:34:47] [I] [TRT] Local timing cache in use. Profiling results in this builder pass will not be stored.
[02/09/2022-23:34:47] [V] [TRT] Constructing optimization profile number 0 [1/1].
[02/09/2022-23:34:47] [E] Error[1]: [resizingAllocator.cpp::allocate::61] Error Code 1: Cuda Runtime (out of memory)
[02/09/2022-23:34:47] [W] [TRT] -------------- The current system memory allocations dump as below --------------
-------------- The current device memory allocations dump as below --------------
[0]:18446744069414584320 :DeviceActivationSize in reserveNetworkTensorMemory: at optimizer/common/tactic/optimizer.cpp: 4596 idx: 8 time: 0.000104068
[02/09/2022-23:34:47] [W] [TRT] Requested amount of GPU memory (18446744069414584320 bytes) could not be allocated. There may not be enough free memory for allocation to succeed.
[02/09/2022-23:34:47] [E] Error[2]: [optimizer.cpp::reserveNetworkTensorMemory::4596] Error Code 2: OutOfMemory (no further information)
[02/09/2022-23:34:47] [E] Error[2]: [builder.cpp::buildSerializedNetwork::609] Error Code 2: Internal Error (Assertion enginePtr != nullptr failed. )
[02/09/2022-23:34:47] [E] Engine could not be created from network
[02/09/2022-23:34:47] [E] Building engine failed
[02/09/2022-23:34:47] [E] Failed to create engine from model.
[02/09/2022-23:34:47] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec [TensorRT v8201] # trtexec --onnx=output.onnx --verbose

Hi @user33875,

This is a known issue. Please upgrade to 8.2.3 when it is released; it will be fixed in future releases.

Thank you.


I installed 8.2.3, but the problem remains. H and W are dynamic dimensions: if H and W are set very small, there is no error; if H and W are set large, the error occurs.
[02/21/2022-01:48:56] [TRT] [W] Requested amount of GPU memory (32768000000 bytes) could not be allocated. There may not be enough free memory for allocation to succeed.
[02/21/2022-01:48:56] [TRT] [W] Skipping tactic 25 due to insuficient memory on requested size of 32768000000 detected for tactic 29.
Try decreasing the workspace size with IBuilderConfig::setMaxWorkspaceSize().
[02/21/2022-01:48:56] [TRT] [E] 1: [resizingAllocator.cpp::allocate::61] Error Code 1: Cuda Runtime (out of memory)

Could you please open a new topic and share the ONNX model that reproduces the issue, along with the complete verbose logs, so we can debug it?

Thank you.

Hi @spolisetty @1752681239, I just tried v8.2.3 too, but it still fails with the same error message. Below is the log of running trtexec --onnx=output.onnx --verbose on the same model:

&&&& RUNNING TensorRT.trtexec [TensorRT v8203] # trtexec --onnx=output.onnx --verbose
[02/21/2022-14:35:22] [I] === Model Options ===
[02/21/2022-14:35:22] [I] Format: ONNX
[02/21/2022-14:35:22] [I] Model: output.onnx
[02/21/2022-14:35:22] [I] Output:
[02/21/2022-14:35:22] [I] === Build Options ===
[02/21/2022-14:35:22] [I] Max batch: explicit batch
[02/21/2022-14:35:22] [I] Workspace: 16 MiB
[02/21/2022-14:35:22] [I] minTiming: 1
[02/21/2022-14:35:22] [I] avgTiming: 8
[02/21/2022-14:35:22] [I] Precision: FP32
[02/21/2022-14:35:22] [I] Calibration: 
[02/21/2022-14:35:22] [I] Refit: Disabled
[02/21/2022-14:35:22] [I] Sparsity: Disabled
[02/21/2022-14:35:22] [I] Safe mode: Disabled
[02/21/2022-14:35:22] [I] DirectIO mode: Disabled
[02/21/2022-14:35:22] [I] Restricted mode: Disabled
[02/21/2022-14:35:22] [I] Save engine: 
[02/21/2022-14:35:22] [I] Load engine: 
[02/21/2022-14:35:22] [I] Profiling verbosity: 0
[02/21/2022-14:35:22] [I] Tactic sources: Using default tactic sources
[02/21/2022-14:35:22] [I] timingCacheMode: local
[02/21/2022-14:35:22] [I] timingCacheFile: 
[02/21/2022-14:35:22] [I] Input(s)s format: fp32:CHW
[02/21/2022-14:35:22] [I] Output(s)s format: fp32:CHW
[02/21/2022-14:35:22] [I] Input build shapes: model
[02/21/2022-14:35:22] [I] Input calibration shapes: model
[02/21/2022-14:35:22] [I] === System Options ===
[02/21/2022-14:35:22] [I] Device: 0
[02/21/2022-14:35:22] [I] DLACore: 
[02/21/2022-14:35:22] [I] Plugins:
[02/21/2022-14:35:22] [I] === Inference Options ===
[02/21/2022-14:35:22] [I] Batch: Explicit
[02/21/2022-14:35:22] [I] Input inference shapes: model
[02/21/2022-14:35:22] [I] Iterations: 10
[02/21/2022-14:35:22] [I] Duration: 3s (+ 200ms warm up)
[02/21/2022-14:35:22] [I] Sleep time: 0ms
[02/21/2022-14:35:22] [I] Idle time: 0ms
[02/21/2022-14:35:22] [I] Streams: 1
[02/21/2022-14:35:22] [I] ExposeDMA: Disabled
[02/21/2022-14:35:22] [I] Data transfers: Enabled
[02/21/2022-14:35:22] [I] Spin-wait: Disabled
[02/21/2022-14:35:22] [I] Multithreading: Disabled
[02/21/2022-14:35:22] [I] CUDA Graph: Disabled
[02/21/2022-14:35:22] [I] Separate profiling: Disabled
[02/21/2022-14:35:22] [I] Time Deserialize: Disabled
[02/21/2022-14:35:22] [I] Time Refit: Disabled
[02/21/2022-14:35:22] [I] Skip inference: Disabled
[02/21/2022-14:35:22] [I] Inputs:
[02/21/2022-14:35:22] [I] === Reporting Options ===
[02/21/2022-14:35:22] [I] Verbose: Enabled
[02/21/2022-14:35:22] [I] Averages: 10 inferences
[02/21/2022-14:35:22] [I] Percentile: 99
[02/21/2022-14:35:22] [I] Dump refittable layers:Disabled
[02/21/2022-14:35:22] [I] Dump output: Disabled
[02/21/2022-14:35:22] [I] Profile: Disabled
[02/21/2022-14:35:22] [I] Export timing to JSON file: 
[02/21/2022-14:35:22] [I] Export output to JSON file: 
[02/21/2022-14:35:22] [I] Export profile to JSON file: 
[02/21/2022-14:35:22] [I] 
[02/21/2022-14:35:22] [I] === Device Information ===
[02/21/2022-14:35:22] [I] Selected Device: NVIDIA GeForce RTX 2080
[02/21/2022-14:35:22] [I] Compute Capability: 7.5
[02/21/2022-14:35:22] [I] SMs: 46
[02/21/2022-14:35:22] [I] Compute Clock Rate: 1.71 GHz
[02/21/2022-14:35:22] [I] Device Global Memory: 7982 MiB
[02/21/2022-14:35:22] [I] Shared Memory per SM: 64 KiB
[02/21/2022-14:35:22] [I] Memory Bus Width: 256 bits (ECC disabled)
[02/21/2022-14:35:22] [I] Memory Clock Rate: 7 GHz
[02/21/2022-14:35:22] [I] 
[02/21/2022-14:35:22] [I] TensorRT version: 8.2.3
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::GridAnchor_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::GridAnchorRect_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::NMS_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::Reorg_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::Region_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::Clip_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::LReLU_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::PriorBox_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::Normalize_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::ScatterND version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::RPROI_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::BatchedNMS_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::BatchedNMSDynamic_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::FlattenConcat_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::CropAndResize version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::DetectionLayer_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::EfficientNMS_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::EfficientNMS_ONNX_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::EfficientNMS_TFTRT_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::Proposal version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::ProposalLayer_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::PyramidROIAlign_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::ResizeNearest_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::Split version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::SpecialSlice_TRT version 1
[02/21/2022-14:35:22] [V] [TRT] Registered plugin creator - ::InstanceNormalization_TRT version 1
[02/21/2022-14:35:23] [I] [TRT] [MemUsageChange] Init CUDA: CPU +321, GPU +0, now: CPU 333, GPU 265 (MiB)
[02/21/2022-14:35:23] [I] [TRT] [MemUsageSnapshot] Begin constructing builder kernel library: CPU 333 MiB, GPU 265 MiB
[02/21/2022-14:35:23] [I] [TRT] [MemUsageSnapshot] End constructing builder kernel library: CPU 468 MiB, GPU 299 MiB
[02/21/2022-14:35:23] [I] Start parsing network model
[02/21/2022-14:35:23] [I] [TRT] ----------------------------------------------------------------
[02/21/2022-14:35:23] [I] [TRT] Input filename:   output.onnx
[02/21/2022-14:35:23] [I] [TRT] ONNX IR version:  0.0.7
[02/21/2022-14:35:23] [I] [TRT] Opset version:    14
[02/21/2022-14:35:23] [I] [TRT] Producer name:    pytorch
[02/21/2022-14:35:23] [I] [TRT] Producer version: 1.10
[02/21/2022-14:35:23] [I] [TRT] Domain:           
[02/21/2022-14:35:23] [I] [TRT] Model version:    0
[02/21/2022-14:35:23] [I] [TRT] Doc string:       
[02/21/2022-14:35:23] [I] [TRT] ----------------------------------------------------------------
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::GridAnchor_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::GridAnchorRect_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::NMS_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::Reorg_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::Region_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::Clip_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::LReLU_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::PriorBox_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::Normalize_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::ScatterND version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::RPROI_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::BatchedNMS_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::BatchedNMSDynamic_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::FlattenConcat_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::CropAndResize version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::DetectionLayer_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::EfficientNMS_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::EfficientNMS_ONNX_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::EfficientNMS_TFTRT_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::Proposal version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::ProposalLayer_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::PyramidROIAlign_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::ResizeNearest_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::Split version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::SpecialSlice_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Plugin creator already registered - ::InstanceNormalization_TRT version 1
[02/21/2022-14:35:23] [V] [TRT] Adding network input: a with dtype: bool, dimensions: (536870912)
[02/21/2022-14:35:23] [V] [TRT] Registering tensor: a for ONNX tensor: a
[02/21/2022-14:35:23] [V] [TRT] Parsing node: Identity_0 [Identity]
[02/21/2022-14:35:23] [V] [TRT] Searching for input: a
[02/21/2022-14:35:23] [V] [TRT] Identity_0 [Identity] inputs: [a -> (536870912)[BOOL]], 
[02/21/2022-14:35:23] [V] [TRT] Registering layer: Identity_0 for ONNX node: Identity_0
[02/21/2022-14:35:23] [V] [TRT] Registering tensor: 1_0 for ONNX tensor: 1
[02/21/2022-14:35:23] [V] [TRT] Identity_0 [Identity] outputs: [1 -> (536870912)[BOOL]], 
[02/21/2022-14:35:23] [V] [TRT] Marking 1_0 as output: 1
[02/21/2022-14:35:23] [I] Finish parsing network model
[02/21/2022-14:35:23] [V] [TRT] Applying generic optimizations to the graph for inference.
[02/21/2022-14:35:23] [V] [TRT] Original: 1 layers
[02/21/2022-14:35:23] [V] [TRT] After dead-layer removal: 1 layers
[02/21/2022-14:35:23] [V] [TRT] After Myelin optimization: 1 layers
[02/21/2022-14:35:23] [V] [TRT] Applying ScaleNodes fusions.
[02/21/2022-14:35:23] [V] [TRT] After scale fusion: 1 layers
[02/21/2022-14:35:23] [V] [TRT] After vertical fusions: 1 layers
[02/21/2022-14:35:23] [V] [TRT] After dupe layer removal: 1 layers
[02/21/2022-14:35:23] [V] [TRT] After final dead-layer removal: 1 layers
[02/21/2022-14:35:23] [V] [TRT] After tensor merging: 1 layers
[02/21/2022-14:35:23] [V] [TRT] After concat removal: 1 layers
[02/21/2022-14:35:23] [V] [TRT] Graph construction and optimization completed in 0.000729956 seconds.
[02/21/2022-14:35:23] [V] [TRT] Using cublasLt as a tactic source
[02/21/2022-14:35:23] [I] [TRT] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +508, GPU +222, now: CPU 976, GPU 521 (MiB)
[02/21/2022-14:35:23] [V] [TRT] Using cuDNN as a tactic source
[02/21/2022-14:35:24] [I] [TRT] [MemUsageChange] Init cuDNN: CPU +113, GPU +52, now: CPU 1089, GPU 573 (MiB)
[02/21/2022-14:35:24] [I] [TRT] Local timing cache in use. Profiling results in this builder pass will not be stored.
[02/21/2022-14:35:24] [V] [TRT] Constructing optimization profile number 0 [1/1].
[02/21/2022-14:35:24] [E] Error[1]: [resizingAllocator.cpp::allocate::61] Error Code 1: Cuda Runtime (out of memory)
[02/21/2022-14:35:24] [W] [TRT] -------------- The current system memory allocations dump as below --------------
-------------- The current device memory allocations dump as below --------------
[0]:18446744069414584320 :DeviceActivationSize in reserveNetworkTensorMemory: at optimizer/common/tactic/optimizer.cpp: 4602 idx: 8 time: 9.4124e-05
[02/21/2022-14:35:24] [W] [TRT] Requested amount of GPU memory (18446744069414584320 bytes) could not be allocated. There may not be enough free memory for allocation to succeed.
[02/21/2022-14:35:24] [E] Error[2]: [optimizer.cpp::reserveNetworkTensorMemory::4602] Error Code 2: OutOfMemory (no further information)
[02/21/2022-14:35:24] [E] Error[2]: [builder.cpp::buildSerializedNetwork::609] Error Code 2: Internal Error (Assertion enginePtr != nullptr failed. )
[02/21/2022-14:35:24] [E] Engine could not be created from network
[02/21/2022-14:35:24] [E] Building engine failed
[02/21/2022-14:35:24] [E] Failed to create engine from model.
[02/21/2022-14:35:24] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec [TensorRT v8203] # trtexec --onnx=output.onnx --verbose

Hi,

We could not reproduce this issue with the latest TensorRT version, 8.4 EA; it has been resolved.
Please try the latest TensorRT version.
https://developer.nvidia.com/nvidia-tensorrt-8x-download

Thank you.

