Duplicated reshapes trigger "[graphOptimizer.cpp::findOne::510] Error Code 2: Internal Error (Assertion it != v.end() failed. )"

Description

I am experimenting with TensorRT and ONNX and got the error in the title on an ONNX model that contains duplicated Reshape ops, where the output of the second Reshape is used more than once (I am not sure whether the exact number of uses matters). I guess the graph optimizer is doing something like common subexpression elimination but fails on this corner case of Reshape. The model looks like this:
[image: graph of the exported ONNX model]

It can be generated with the following Python code:

import numpy as np
import onnx
import onnx.checker
import torch


inputs = {
    'i3': np.random.rand(2, 1, 1, 2).astype(np.float32),
}

class Model(torch.nn.Module):
    @torch.no_grad()
    def forward(self, i3):
        a = i3.reshape(-1)
        b = i3.reshape(-1)
        c = b + b
        d = a + b
        return (d, c)
        # single output like below also triggers the bug
        # d = a + b / torch.mean(b * b)
        # return (d,)


model = Model()
model.eval()
torch_inp = {k: torch.from_numpy(v) for k, v in inputs.items()}
torch.onnx.export(
    model, tuple(torch_inp.values()),
    "output.onnx", input_names=list(inputs.keys()),
    verbose=False, opset_version=14)

onnx_model = onnx.load("output.onnx")
onnx.checker.check_model(onnx_model, full_check=True)

I am not sure whether this bizarre model is worth investigating for TensorRT, but it reminds me of one use case (though not really production style): computing self-attention by brute force, which triggers the same error:

class Model(torch.nn.Module):
    @torch.no_grad()
    def forward(self, i3):
        keys = [i3]  # potentially more keys; only one here for simplicity
        cs = []
        for k1 in keys:
            for k2 in keys:
                a = k1.reshape(-1)
                b = k2.reshape(-1)
                # c = torch.sum(a * b)  # the usual dot-product kernel builds without error
                c = torch.sum(a + b / torch.mean(b * b)) # one might use other kernels, like the one I just made up
                cs.append(c)
        return cs
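For reference, the loop above reshapes both keys for every pair, which is what produces the duplicated Reshape nodes. Because the made-up kernel satisfies sum(a + b / m) = sum(a) + sum(b) / m for the scalar m = mean(b * b), the same values can be computed with each key flattened exactly once. The NumPy sketch below (function names are hypothetical) checks that equivalence; whether a PyTorch model structured this way sidesteps the TensorRT assertion has not been tested here.

```python
import numpy as np


def kernel_loop(keys):
    """Same computation as the brute-force loop: reshape both keys for every pair."""
    cs = []
    for k1 in keys:
        for k2 in keys:
            a = k1.reshape(-1)
            b = k2.reshape(-1)
            cs.append(np.sum(a + b / np.mean(b * b)))
    return np.array(cs)


def kernel_flatten_once(keys):
    """Flatten each key once; uses sum(a + b/m) == sum(a) + sum(b)/m for scalar m."""
    flat = np.stack([k.reshape(-1) for k in keys])  # shape (n, D)
    sums = flat.sum(axis=1)                         # sum(a) for each key
    scaled = sums / np.mean(flat * flat, axis=1)    # sum(b) / mean(b * b) for each key
    # Entry [i, j] pairs k1 = keys[i] with k2 = keys[j]; row-major flattening
    # reproduces the nested-loop order of `cs`.
    return (sums[:, None] + scaled[None, :]).reshape(-1)
```

The keys must all have the same number of elements, as in the original loop, since `a + b` is elementwise.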

Environment

TensorRT Version : v8203
GPU Type : RTX 2080
Nvidia Driver Version : 495.29.05
CUDA Version : 11.5
CUDNN Version : 8.3.0
Operating System + Version : Ubuntu 18.04
Python Version (if applicable) : 3.8.10
TensorFlow Version (if applicable) :
PyTorch Version (if applicable) : 1.10.2+cpu
Baremetal or Container (if container which image + tag) : Container based on nvcr.io/nvidia/tensorrt:21.11-py3

Relevant Files

output.onnx.zip (417 Bytes)

Steps To Reproduce

Run trtexec --onnx=output.onnx --verbose.

The full output:

&&&& RUNNING TensorRT.trtexec [TensorRT v8203] # trtexec --onnx=output.onnx --verbose
[02/15/2022-16:00:38] [I] === Model Options ===
[02/15/2022-16:00:38] [I] Format: ONNX
[02/15/2022-16:00:38] [I] Model: output.onnx
[02/15/2022-16:00:38] [I] Output:
[02/15/2022-16:00:38] [I] === Build Options ===
[02/15/2022-16:00:38] [I] Max batch: explicit batch
[02/15/2022-16:00:38] [I] Workspace: 16 MiB
[02/15/2022-16:00:38] [I] minTiming: 1
[02/15/2022-16:00:38] [I] avgTiming: 8
[02/15/2022-16:00:38] [I] Precision: FP32
[02/15/2022-16:00:38] [I] Calibration: 
[02/15/2022-16:00:38] [I] Refit: Disabled
[02/15/2022-16:00:38] [I] Sparsity: Disabled
[02/15/2022-16:00:38] [I] Safe mode: Disabled
[02/15/2022-16:00:38] [I] DirectIO mode: Disabled
[02/15/2022-16:00:38] [I] Restricted mode: Disabled
[02/15/2022-16:00:38] [I] Save engine: 
[02/15/2022-16:00:38] [I] Load engine: 
[02/15/2022-16:00:38] [I] Profiling verbosity: 0
[02/15/2022-16:00:38] [I] Tactic sources: Using default tactic sources
[02/15/2022-16:00:38] [I] timingCacheMode: local
[02/15/2022-16:00:38] [I] timingCacheFile: 
[02/15/2022-16:00:38] [I] Input(s)s format: fp32:CHW
[02/15/2022-16:00:38] [I] Output(s)s format: fp32:CHW
[02/15/2022-16:00:38] [I] Input build shapes: model
[02/15/2022-16:00:38] [I] Input calibration shapes: model
[02/15/2022-16:00:38] [I] === System Options ===
[02/15/2022-16:00:38] [I] Device: 0
[02/15/2022-16:00:38] [I] DLACore: 
[02/15/2022-16:00:38] [I] Plugins:
[02/15/2022-16:00:38] [I] === Inference Options ===
[02/15/2022-16:00:38] [I] Batch: Explicit
[02/15/2022-16:00:38] [I] Input inference shapes: model
[02/15/2022-16:00:38] [I] Iterations: 10
[02/15/2022-16:00:38] [I] Duration: 3s (+ 200ms warm up)
[02/15/2022-16:00:38] [I] Sleep time: 0ms
[02/15/2022-16:00:38] [I] Idle time: 0ms
[02/15/2022-16:00:38] [I] Streams: 1
[02/15/2022-16:00:38] [I] ExposeDMA: Disabled
[02/15/2022-16:00:38] [I] Data transfers: Enabled
[02/15/2022-16:00:38] [I] Spin-wait: Disabled
[02/15/2022-16:00:38] [I] Multithreading: Disabled
[02/15/2022-16:00:38] [I] CUDA Graph: Disabled
[02/15/2022-16:00:38] [I] Separate profiling: Disabled
[02/15/2022-16:00:38] [I] Time Deserialize: Disabled
[02/15/2022-16:00:38] [I] Time Refit: Disabled
[02/15/2022-16:00:38] [I] Skip inference: Disabled
[02/15/2022-16:00:38] [I] Inputs:
[02/15/2022-16:00:38] [I] === Reporting Options ===
[02/15/2022-16:00:38] [I] Verbose: Enabled
[02/15/2022-16:00:38] [I] Averages: 10 inferences
[02/15/2022-16:00:38] [I] Percentile: 99
[02/15/2022-16:00:38] [I] Dump refittable layers:Disabled
[02/15/2022-16:00:38] [I] Dump output: Disabled
[02/15/2022-16:00:38] [I] Profile: Disabled
[02/15/2022-16:00:38] [I] Export timing to JSON file: 
[02/15/2022-16:00:38] [I] Export output to JSON file: 
[02/15/2022-16:00:38] [I] Export profile to JSON file: 
[02/15/2022-16:00:38] [I] 
[02/15/2022-16:00:38] [I] === Device Information ===
[02/15/2022-16:00:38] [I] Selected Device: NVIDIA GeForce RTX 2080
[02/15/2022-16:00:38] [I] Compute Capability: 7.5
[02/15/2022-16:00:38] [I] SMs: 46
[02/15/2022-16:00:38] [I] Compute Clock Rate: 1.71 GHz
[02/15/2022-16:00:38] [I] Device Global Memory: 7982 MiB
[02/15/2022-16:00:38] [I] Shared Memory per SM: 64 KiB
[02/15/2022-16:00:38] [I] Memory Bus Width: 256 bits (ECC disabled)
[02/15/2022-16:00:38] [I] Memory Clock Rate: 7 GHz
[02/15/2022-16:00:38] [I] 
[02/15/2022-16:00:38] [I] TensorRT version: 8.2.3
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::GridAnchor_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::GridAnchorRect_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::NMS_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::Reorg_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::Region_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::Clip_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::LReLU_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::PriorBox_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::Normalize_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::ScatterND version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::RPROI_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::BatchedNMS_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::BatchedNMSDynamic_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::FlattenConcat_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::CropAndResize version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::DetectionLayer_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::EfficientNMS_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::EfficientNMS_ONNX_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::EfficientNMS_TFTRT_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::Proposal version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::ProposalLayer_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::PyramidROIAlign_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::ResizeNearest_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::Split version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::SpecialSlice_TRT version 1
[02/15/2022-16:00:38] [V] [TRT] Registered plugin creator - ::InstanceNormalization_TRT version 1
[02/15/2022-16:00:38] [I] [TRT] [MemUsageChange] Init CUDA: CPU +321, GPU +0, now: CPU 333, GPU 265 (MiB)
[02/15/2022-16:00:39] [I] [TRT] [MemUsageSnapshot] Begin constructing builder kernel library: CPU 333 MiB, GPU 265 MiB
[02/15/2022-16:00:39] [I] [TRT] [MemUsageSnapshot] End constructing builder kernel library: CPU 468 MiB, GPU 299 MiB
[02/15/2022-16:00:39] [I] Start parsing network model
[02/15/2022-16:00:39] [I] [TRT] ----------------------------------------------------------------
[02/15/2022-16:00:39] [I] [TRT] Input filename:   output.onnx
[02/15/2022-16:00:39] [I] [TRT] ONNX IR version:  0.0.7
[02/15/2022-16:00:39] [I] [TRT] Opset version:    14
[02/15/2022-16:00:39] [I] [TRT] Producer name:    pytorch
[02/15/2022-16:00:39] [I] [TRT] Producer version: 1.10
[02/15/2022-16:00:39] [I] [TRT] Domain:           
[02/15/2022-16:00:39] [I] [TRT] Model version:    0
[02/15/2022-16:00:39] [I] [TRT] Doc string:       
[02/15/2022-16:00:39] [I] [TRT] ----------------------------------------------------------------
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::GridAnchor_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::GridAnchorRect_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::NMS_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::Reorg_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::Region_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::Clip_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::LReLU_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::PriorBox_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::Normalize_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::ScatterND version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::RPROI_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::BatchedNMS_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::BatchedNMSDynamic_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::FlattenConcat_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::CropAndResize version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::DetectionLayer_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::EfficientNMS_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::EfficientNMS_ONNX_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::EfficientNMS_TFTRT_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::Proposal version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::ProposalLayer_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::PyramidROIAlign_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::ResizeNearest_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::Split version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::SpecialSlice_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Plugin creator already registered - ::InstanceNormalization_TRT version 1
[02/15/2022-16:00:39] [V] [TRT] Adding network input: i3 with dtype: float32, dimensions: (2, 1, 1, 2)
[02/15/2022-16:00:39] [V] [TRT] Registering tensor: i3 for ONNX tensor: i3
[02/15/2022-16:00:39] [V] [TRT] Parsing node: Constant_0 [Constant]
[02/15/2022-16:00:39] [V] [TRT] Constant_0 [Constant] inputs: 
[02/15/2022-16:00:39] [W] [TRT] onnx2trt_utils.cpp:366: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[02/15/2022-16:00:39] [V] [TRT] Constant_0 [Constant] outputs: [1 -> (1)[INT32]], 
[02/15/2022-16:00:39] [V] [TRT] Parsing node: Reshape_1 [Reshape]
[02/15/2022-16:00:39] [V] [TRT] Searching for input: i3
[02/15/2022-16:00:39] [V] [TRT] Searching for input: 1
[02/15/2022-16:00:39] [V] [TRT] Reshape_1 [Reshape] inputs: [i3 -> (2, 1, 1, 2)[FLOAT]], [1 -> (1)[INT32]], 
[02/15/2022-16:00:39] [V] [TRT] Registering layer: Reshape_1 for ONNX node: Reshape_1
[02/15/2022-16:00:39] [V] [TRT] Registering tensor: 2 for ONNX tensor: 2
[02/15/2022-16:00:39] [V] [TRT] Reshape_1 [Reshape] outputs: [2 -> (4)[FLOAT]], 
[02/15/2022-16:00:39] [V] [TRT] Parsing node: Constant_2 [Constant]
[02/15/2022-16:00:39] [V] [TRT] Constant_2 [Constant] inputs: 
[02/15/2022-16:00:39] [V] [TRT] Constant_2 [Constant] outputs: [3 -> (1)[INT32]], 
[02/15/2022-16:00:39] [V] [TRT] Parsing node: Reshape_3 [Reshape]
[02/15/2022-16:00:39] [V] [TRT] Searching for input: i3
[02/15/2022-16:00:39] [V] [TRT] Searching for input: 3
[02/15/2022-16:00:39] [V] [TRT] Reshape_3 [Reshape] inputs: [i3 -> (2, 1, 1, 2)[FLOAT]], [3 -> (1)[INT32]], 
[02/15/2022-16:00:39] [V] [TRT] Registering layer: Reshape_3 for ONNX node: Reshape_3
[02/15/2022-16:00:39] [V] [TRT] Registering tensor: 4 for ONNX tensor: 4
[02/15/2022-16:00:39] [V] [TRT] Reshape_3 [Reshape] outputs: [4 -> (4)[FLOAT]], 
[02/15/2022-16:00:39] [V] [TRT] Parsing node: Add_4 [Add]
[02/15/2022-16:00:39] [V] [TRT] Searching for input: 4
[02/15/2022-16:00:39] [V] [TRT] Searching for input: 4
[02/15/2022-16:00:39] [V] [TRT] Add_4 [Add] inputs: [4 -> (4)[FLOAT]], [4 -> (4)[FLOAT]], 
[02/15/2022-16:00:39] [V] [TRT] Registering layer: Add_4 for ONNX node: Add_4
[02/15/2022-16:00:39] [V] [TRT] Registering tensor: 5_0 for ONNX tensor: 5
[02/15/2022-16:00:39] [V] [TRT] Add_4 [Add] outputs: [5 -> (4)[FLOAT]], 
[02/15/2022-16:00:39] [V] [TRT] Parsing node: Add_5 [Add]
[02/15/2022-16:00:39] [V] [TRT] Searching for input: 2
[02/15/2022-16:00:39] [V] [TRT] Searching for input: 4
[02/15/2022-16:00:39] [V] [TRT] Add_5 [Add] inputs: [2 -> (4)[FLOAT]], [4 -> (4)[FLOAT]], 
[02/15/2022-16:00:39] [V] [TRT] Registering layer: Add_5 for ONNX node: Add_5
[02/15/2022-16:00:39] [V] [TRT] Registering tensor: 6_1 for ONNX tensor: 6
[02/15/2022-16:00:39] [V] [TRT] Add_5 [Add] outputs: [6 -> (4)[FLOAT]], 
[02/15/2022-16:00:39] [V] [TRT] Marking 6_1 as output: 6
[02/15/2022-16:00:39] [V] [TRT] Marking 5_0 as output: 5
[02/15/2022-16:00:39] [I] Finish parsing network model
[02/15/2022-16:00:39] [V] [TRT] Applying generic optimizations to the graph for inference.
[02/15/2022-16:00:39] [V] [TRT] Original: 4 layers
[02/15/2022-16:00:39] [V] [TRT] After dead-layer removal: 4 layers
[02/15/2022-16:00:39] [V] [TRT] After Myelin optimization: 4 layers
[02/15/2022-16:00:39] [V] [TRT] Applying ScaleNodes fusions.
[02/15/2022-16:00:39] [V] [TRT] After scale fusion: 4 layers
[02/15/2022-16:00:39] [V] [TRT] After vertical fusions: 4 layers
[02/15/2022-16:00:39] [V] [TRT] Replacing input 0 of Add_4 with 2
[02/15/2022-16:00:39] [V] [TRT] Replacing input 1 of Add_5 with 2
[02/15/2022-16:00:39] [E] Error[2]: [graphOptimizer.cpp::findOne::510] Error Code 2: Internal Error (Assertion it != v.end() failed. )
[02/15/2022-16:00:39] [E] Error[2]: [builder.cpp::buildSerializedNetwork::609] Error Code 2: Internal Error (Assertion enginePtr != nullptr failed. )
[02/15/2022-16:00:39] [E] Engine could not be created from network
[02/15/2022-16:00:39] [E] Building engine failed
[02/15/2022-16:00:39] [E] Failed to create engine from model.
[02/15/2022-16:00:39] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec [TensorRT v8203] # trtexec --onnx=output.onnx --verbose
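The two "Replacing input" lines right before the assertion fit the guess about common subexpression elimination: TensorRT appears to merge Reshape_3 (tensor 4) into Reshape_1 (tensor 2) and rewire their consumers. The sketch below is a purely hypothetical toy in Python, not TensorRT's code. It assumes the optimizer keeps one consumer entry per layer for each tensor and rewires input slots one at a time. Under that assumption, Add_4, which reads tensor 4 in both slots, is dropped from the consumer list after its first slot is rewired, so the lookup for its second slot fails, much like an `it != v.end()` assertion. All function names are invented, and the toy's rewiring order need not match the order TensorRT logs.

```python
# Hypothetical toy model of a CSE rewiring step; NOT TensorRT's implementation.
# Assumption: each tensor keeps ONE consumer entry per layer, even when a layer
# reads that tensor in more than one input slot.


def build_consumers(layers):
    """Map tensor name -> list of consuming layer names (one entry per layer)."""
    consumers = {}
    for name, inputs in layers.items():
        for tensor in inputs:
            entries = consumers.setdefault(tensor, [])
            if name not in entries:
                entries.append(name)
    return consumers


def replace_input(layers, consumers, layer, slot, new_tensor):
    """Naively rewire a single input slot and move the consumer entry."""
    old_tensor = layers[layer][slot]
    entries = consumers[old_tensor]
    # Analogue of findOne's "it != v.end()" check.
    assert layer in entries, f"{layer} not among consumers of {old_tensor}"
    entries.remove(layer)
    layers[layer][slot] = new_tensor
    new_entries = consumers.setdefault(new_tensor, [])
    if layer not in new_entries:
        new_entries.append(layer)


def merge_naive(layers, consumers, dup, keep):
    """CSE step: redirect every use of `dup` to `keep`, one slot at a time."""
    for layer, inputs in layers.items():
        for slot, tensor in enumerate(inputs):
            if tensor == dup:
                replace_input(layers, consumers, layer, slot, keep)


def merge_safe(layers, dup, keep):
    """Rewrite all slots first, then rebuild the consumer lists from scratch."""
    for inputs in layers.values():
        for slot, tensor in enumerate(inputs):
            if tensor == dup:
                inputs[slot] = keep
    return build_consumers(layers)


if __name__ == "__main__":
    # Consumers named as in the trtexec log: Add_4 = 4 + 4, Add_5 = 2 + 4,
    # where tensor 4 (Reshape_3) duplicates tensor 2 (Reshape_1).
    layers = {"Add_4": ["4", "4"], "Add_5": ["2", "4"]}
    try:
        merge_naive(layers, build_consumers(layers), dup="4", keep="2")
    except AssertionError as err:
        print("naive merge failed:", err)

    layers = {"Add_4": ["4", "4"], "Add_5": ["2", "4"]}
    print(merge_safe(layers, dup="4", keep="2"), layers)
```

Rebuilding the consumer lists after all slots are rewritten, as `merge_safe` does, is one way to make such a pass robust to a layer that reads the same tensor twice; whether the real fix in TensorRT takes this form is unknown.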

Hi,
Please share the ONNX model and the script, if not already shared, so that we can assist you better.
Meanwhile, you can try a few things:

  1. Validate your model with the snippet below:

check_model.py

import sys
import onnx

filename = sys.argv[1]  # path to your ONNX model
model = onnx.load(filename)
onnx.checker.check_model(model)
  2. Try running your model with the trtexec command:
https://github.com/NVIDIA/TensorRT/tree/master/samples/opensource/trtexec
If you are still facing the issue, please share the trtexec --verbose log for further debugging.
Thanks!

Hi,

Sorry, I think the model has already been shared (see the attached file output.onnx.zip) and validated (see the end of the first piece of code). Do you need any further information or actions from my side?

Thanks!

Hi @NVES, any follow-ups?

Hi,

We could reproduce the same error. We will get back to you.

Thank you.
