ONNX If conditional: both the true-branch and false-branch subgraphs are executed

Description

When using an ONNX `If` node, TensorRT executes both the then-branch and the else-branch subgraphs instead of only the selected one.
python script generate onnx file:
modify_onnx_simp.txt (5.5 KB)

import time
import onnx
import queue
import onnxsim
import numpy as np
import onnxruntime as ort
import onnx_graphsurgeon as gs


def two_matmul():
    """Build and save a baseline ONNX model with two independent MatMul nodes.

    Writes ``two_matmul.onnx`` to the current directory. This is the
    comparison model (no If node) for ``two_matmul_if``.
    """
    batch_size, num_output_feat = 300, 8096
    num_input_1_feat, num_input_2_feat = 1024, 4096
    dst_onnx_path = 'two_matmul.onnx'

    # MatMul node 1: (300, 1024) x (1024, 8096) -> (300, 8096)
    input_1 = gs.Variable(
            name='input_1',
            dtype=np.float32,
            shape=(batch_size, num_input_1_feat)
            )
    matmul_1_weight = gs.Constant(
            name='matmul_1_weight',
            values=np.random.rand(num_input_1_feat, num_output_feat).astype(np.float32)
            )
    matmul_1_output = gs.Variable(
            name='matmul_1_output',
            dtype=np.float32,
            shape=(batch_size, num_output_feat)
            )
    matmul_node_1 = gs.Node(
            op='MatMul',
            name='MatMul_1',
            inputs=[input_1, matmul_1_weight],
            outputs=[matmul_1_output]
            )

    # MatMul node 2: (300, 4096) x (4096, 8096) -> (300, 8096)
    input_2 = gs.Variable(
            name='input_2',
            dtype=np.float32,
            shape=(batch_size, num_input_2_feat)
            )
    matmul_2_weight = gs.Constant(
            name='matmul_2_weight',
            values=np.random.rand(num_input_2_feat, num_output_feat).astype(np.float32)
            )
    matmul_2_output = gs.Variable(
            name='matmul_2_output',
            dtype=np.float32,
            shape=(batch_size, num_output_feat)
            )
    matmul_node_2 = gs.Node(
            op='MatMul',
            name='MatMul_2',
            inputs=[input_2, matmul_2_weight],
            outputs=[matmul_2_output]
            )

    # Top-level graph exposing both MatMul results as outputs.
    graph = gs.Graph(
            nodes=[matmul_node_1, matmul_node_2],
            inputs=[input_1, input_2],
            outputs=[matmul_1_output, matmul_2_output],
            name='two_matmul',
            opset=11
            )

    # Run shape inference and validate before saving so any structural
    # error surfaces here rather than at engine-build time.
    dst_onnx_model = gs.export_onnx(graph)
    dst_onnx_model = onnx.shape_inference.infer_shapes(dst_onnx_model, data_prop=True)
    onnx.checker.check_model(dst_onnx_model)
    onnx.save(dst_onnx_model, dst_onnx_path)


def two_matmul_if():
    """Build and save an ONNX model whose two MatMuls are gated by an If node.

    Writes ``two_matmul_if.onnx`` to the current directory. The scalar int32
    input ``input_flag`` is compared against the constant 1: when equal, the
    then-branch (MatMul_1) should run; otherwise the else-branch (MatMul_2).
    """
    batch_size, num_output_feat = 300, 8096
    num_input_1_feat, num_input_2_feat = 1024, 4096
    dst_onnx_path = 'two_matmul_if.onnx'

    # Then-branch subgraph: MatMul_1, (300, 1024) x (1024, 8096).
    input_1 = gs.Variable(
            name='input_1',
            dtype=np.float32,
            shape=(batch_size, num_input_1_feat)
            )
    matmul_1_weight = gs.Constant(
            name='matmul_1_weight',
            values=np.random.rand(num_input_1_feat, num_output_feat).astype(np.float32)
            )
    matmul_1_output = gs.Variable(
            name='matmul_1_output',
            dtype=np.float32,
            shape=(batch_size, num_output_feat)
            )
    matmul_node_1 = gs.Node(
            op='MatMul',
            name='MatMul_1',
            inputs=[input_1, matmul_1_weight],
            outputs=[matmul_1_output]
            )
    subgraph_1 = gs.Graph(
            nodes=[matmul_node_1],
            inputs=[input_1],
            outputs=[matmul_1_output],
            name='subgraph_1',
            opset=13
            )

    # Else-branch subgraph: MatMul_2, (300, 4096) x (4096, 8096).
    input_2 = gs.Variable(
            name='input_2',
            dtype=np.float32,
            shape=(batch_size, num_input_2_feat)
            )
    matmul_2_weight = gs.Constant(
            name='matmul_2_weight',
            values=np.random.rand(num_input_2_feat, num_output_feat).astype(np.float32)
            )
    matmul_2_output = gs.Variable(
            name='matmul_2_output',
            dtype=np.float32,
            shape=(batch_size, num_output_feat)
            )
    matmul_node_2 = gs.Node(
            op='MatMul',
            name='MatMul_2',
            inputs=[input_2, matmul_2_weight],
            outputs=[matmul_2_output]
            )
    subgraph_2 = gs.Graph(
            nodes=[matmul_node_2],
            inputs=[input_2],
            outputs=[matmul_2_output],
            name='subgraph_2',
            opset=13
            )

    # Equal node producing the boolean If condition: input_flag == 1.
    input_flag = gs.Variable(
            name='input_flag',
            dtype=np.int32,
            shape=()
            )
    equal_target = gs.Constant(
            name='Equal_target',
            values=np.array(1, dtype=np.int32)
            )
    equal_output = gs.Variable(
            name='equal_output',
            dtype=np.bool_,
            shape=()
            )
    equal_node = gs.Node(
            op='Equal',
            name='Equal',
            inputs=[input_flag, equal_target],
            outputs=[equal_output]
            )

    # If node: selects subgraph_1 (then) or subgraph_2 (else) by the
    # Equal result; both branches produce the same (300, 8096) shape.
    if_output = gs.Variable(
            name='if_output',
            dtype=np.float32,
            shape=(batch_size, num_output_feat)
            )
    if_node = gs.Node(
            op='If',
            name='If',
            attrs={'then_branch': subgraph_1, 'else_branch': subgraph_2},
            inputs=[equal_output],
            outputs=[if_output]
            )

    # Top-level graph; the subgraph inputs are wired up as outer inputs.
    graph = gs.Graph(
            nodes=[equal_node, if_node],
            inputs=[input_1, input_2, input_flag],
            outputs=[if_output],
            name='two_matmul_if',
            opset=13
            )

    # Run shape inference and validate before saving so any structural
    # error surfaces here rather than at engine-build time.
    dst_onnx_model = gs.export_onnx(graph)
    dst_onnx_model = onnx.shape_inference.infer_shapes(dst_onnx_model, data_prop=True)
    onnx.checker.check_model(dst_onnx_model)
    onnx.save(dst_onnx_model, dst_onnx_path)


if __name__ == '__main__':
    # Generate both repro models: the plain two-MatMul baseline and the
    # If-gated variant.
    for build_model in (two_matmul, two_matmul_if):
        build_model()

Environment

TensorRT Version: 10.2.0
GPU Type: A30
Nvidia Driver Version: 470.57.02
CUDA Version: 11.4
CUDNN Version:
Operating System + Version: centos7
Python Version (if applicable):
TensorFlow Version (if applicable):
PyTorch Version (if applicable):
Baremetal or Container (if container which image + tag):

Relevant Files

Please attach or include links to any models, data, files, or scripts necessary to reproduce your issue. (Github repo, Google Drive, Dropbox, etc.)

Steps To Reproduce

generate onnx model two_matmul_if.onnx

python3 modify_onnx_simp.py

generate two_matmul_if.engine

trtexec --builderOptimizationLevel=3 --maxAuxStreams=0 --useCudaGraph --onnx=two_matmul_if.onnx --saveEngine=two_matmul_if.engine 

generate flag input

np.array([0], dtype=np.int32).reshape((1, )).tofile('flag_0.bin')
np.array([1], dtype=np.int32).reshape((1, )).tofile('flag_1.bin')

profile engine with flag_0.bin; since Equal_target is 1, expect only the else-branch (MatMul_2) to run.

trtexec --warmUp=100  --iterations=10000 --useCudaGraph  --separateProfileRun  --dumpProfile --useSpinWait --noDataTransfers --loadEngine=two_matmul_if.engine --loadInputs='input_flag':flag_0.bin

profile engine with flag_1.bin; since Equal_target is 1, expect only the then-branch (MatMul_1) to run.

trtexec --warmUp=100  --iterations=10000 --useCudaGraph  --separateProfileRun  --dumpProfile --useSpinWait --noDataTransfers --loadEngine=two_matmul_if.engine --loadInputs='input_flag':flag_1.bin

input flag_0.bin and flag_1.bin profile is same, both MatMul_1 and MatMul_2 run.

[07/15/2024-20:17:13] [I]    Time(ms)     Avg.(ms)   Median(ms)   Time(%)   Layer
[07/15/2024-20:17:13] [I]       23.00       0.0023       0.0022       0.3   __mye186_hc_init_myl0_0
[07/15/2024-20:17:13] [I]     1313.88       0.1301       0.1297      18.1   MatMul_1_myl0_1
[07/15/2024-20:17:13] [I]     5517.65       0.5462       0.5486      76.1   MatMul_2_myl0_2
[07/15/2024-20:17:13] [I]      140.25       0.0139       0.0129       1.9   copy_d2h___mye161_myl0_3
[07/15/2024-20:17:13] [I]       23.69       0.0023       0.0023       0.3   __mye33cbr_myl0_4
[07/15/2024-20:17:13] [I]      211.70       0.0210       0.0208       2.9   __myl_Mov_myl0_5
[07/15/2024-20:17:13] [I]       22.79       0.0023       0.0022       0.3   jmp__mye39_myl0_6
[07/15/2024-20:17:13] [I]     7252.96       0.7180       0.7177     100.0   Total

Please include:

  • Exact steps/commands to build your repro
  • Exact steps/commands to run your repro
  • Full traceback of errors encountered