Description
When an If condition is used in ONNX, TensorRT executes both the then-branch (true) and the else-branch (false) subgraphs instead of only the selected one.
Python script that generates the ONNX file:
modify_onnx_simp.txt (5.5 KB)
import time
import onnx
import queue
import onnxsim
import numpy as np
import onnxruntime as ort
import onnx_graphsurgeon as gs
def two_matmul():
    batch_size, num_output_feat = 300, 8096
    num_input_1_feat, num_input_2_feat = 1024, 4096
    dst_onnx_path = 'two_matmul.onnx'
    # matmul node 1
    input_1 = gs.Variable(
        name='input_1',
        dtype=np.float32,
        shape=(batch_size, num_input_1_feat)
    )
    matmul_1_weight = gs.Constant(
        name='matmul_1_weight',
        values=np.random.rand(num_input_1_feat, num_output_feat).astype(np.float32)
    )
    matmul_1_output = gs.Variable(
        name='matmul_1_output',
        dtype=np.float32,
        shape=(batch_size, num_output_feat)
    )
    matmul_node_1 = gs.Node(
        op='MatMul',
        name='MatMul_1',
        inputs=[input_1, matmul_1_weight],
        outputs=[matmul_1_output]
    )
    # matmul node 2
    input_2 = gs.Variable(
        name='input_2',
        dtype=np.float32,
        shape=(batch_size, num_input_2_feat)
    )
    matmul_2_weight = gs.Constant(
        name='matmul_2_weight',
        values=np.random.rand(num_input_2_feat, num_output_feat).astype(np.float32)
    )
    matmul_2_output = gs.Variable(
        name='matmul_2_output',
        dtype=np.float32,
        shape=(batch_size, num_output_feat)
    )
    matmul_node_2 = gs.Node(
        op='MatMul',
        name='MatMul_2',
        inputs=[input_2, matmul_2_weight],
        outputs=[matmul_2_output]
    )
    graph = gs.Graph(
        nodes=[matmul_node_1, matmul_node_2],
        inputs=[input_1, input_2],
        outputs=[matmul_1_output, matmul_2_output],
        name='two_matmul',
        opset=11
    )
    dst_onnx_model = gs.export_onnx(graph)
    dst_onnx_model = onnx.shape_inference.infer_shapes(dst_onnx_model, data_prop=True)
    onnx.checker.check_model(dst_onnx_model)
    onnx.save(dst_onnx_model, dst_onnx_path)
def two_matmul_if():
    batch_size, num_output_feat = 300, 8096
    num_input_1_feat, num_input_2_feat = 1024, 4096
    dst_onnx_path = 'two_matmul_if.onnx'
    # matmul node 1 (then-branch)
    input_1 = gs.Variable(
        name='input_1',
        dtype=np.float32,
        shape=(batch_size, num_input_1_feat)
    )
    matmul_1_weight = gs.Constant(
        name='matmul_1_weight',
        values=np.random.rand(num_input_1_feat, num_output_feat).astype(np.float32)
    )
    matmul_1_output = gs.Variable(
        name='matmul_1_output',
        dtype=np.float32,
        shape=(batch_size, num_output_feat)
    )
    matmul_node_1 = gs.Node(
        op='MatMul',
        name='MatMul_1',
        inputs=[input_1, matmul_1_weight],
        outputs=[matmul_1_output]
    )
    subgraph_1 = gs.Graph(
        nodes=[matmul_node_1],
        inputs=[input_1],
        outputs=[matmul_1_output],
        name='subgraph_1',
        opset=13
    )
    # matmul node 2 (else-branch)
    input_2 = gs.Variable(
        name='input_2',
        dtype=np.float32,
        shape=(batch_size, num_input_2_feat)
    )
    matmul_2_weight = gs.Constant(
        name='matmul_2_weight',
        values=np.random.rand(num_input_2_feat, num_output_feat).astype(np.float32)
    )
    matmul_2_output = gs.Variable(
        name='matmul_2_output',
        dtype=np.float32,
        shape=(batch_size, num_output_feat)
    )
    matmul_node_2 = gs.Node(
        op='MatMul',
        name='MatMul_2',
        inputs=[input_2, matmul_2_weight],
        outputs=[matmul_2_output]
    )
    subgraph_2 = gs.Graph(
        nodes=[matmul_node_2],
        inputs=[input_2],
        outputs=[matmul_2_output],
        name='subgraph_2',
        opset=13
    )
    # equal node: compares input_flag against the constant target to produce the If condition
    input_flag = gs.Variable(
        name='input_flag',
        dtype=np.int32,
        shape=()
    )
    equal_target = gs.Constant(
        name='Equal_target',
        values=np.array(1, dtype=np.int32)
        # values=np.array([0], dtype=np.int32).reshape((1, ))
    )
    equal_output = gs.Variable(
        name='equal_output',
        dtype=np.bool_,
        shape=()
    )
    equal_node = gs.Node(
        op='Equal',
        name='Equal',
        inputs=[input_flag, equal_target],
        outputs=[equal_output]
    )
    # if node: selects subgraph_1 (then) or subgraph_2 (else) based on equal_output
    if_output = gs.Variable(
        name='if_output',
        dtype=np.float32,
        shape=(batch_size, num_output_feat)
    )
    if_node = gs.Node(
        op='If',
        name='If',
        attrs={'then_branch': subgraph_1, 'else_branch': subgraph_2},
        inputs=[equal_output],
        outputs=[if_output]
    )
    # total graph
    graph = gs.Graph(
        nodes=[equal_node, if_node],
        inputs=[input_1, input_2, input_flag],
        outputs=[if_output],
        name='two_matmul_if',
        opset=13
    )
    dst_onnx_model = gs.export_onnx(graph)
    dst_onnx_model = onnx.shape_inference.infer_shapes(dst_onnx_model, data_prop=True)
    onnx.checker.check_model(dst_onnx_model)
    onnx.save(dst_onnx_model, dst_onnx_path)

if __name__ == '__main__':
    two_matmul()
    two_matmul_if()
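As a sanity check on the expected ONNX semantics (not part of the original repro), the generated two_matmul_if.onnx can be run through onnxruntime: with the Equal target of 1 as in the script above, input_flag == 1 selects the then-branch (MatMul_1), so varying input_2 should not change if_output. A minimal sketch, assuming onnxruntime accepts the branch subgraphs exactly as exported above:
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('two_matmul_if.onnx', providers=['CPUExecutionProvider'])
x1 = np.random.rand(300, 1024).astype(np.float32)
flag = np.array(1, dtype=np.int32)  # scalar, matching input_flag's shape ()

# With the then-branch selected, changing the else-branch input must not change the output.
out_a = sess.run(['if_output'], {'input_1': x1,
                                 'input_2': np.random.rand(300, 4096).astype(np.float32),
                                 'input_flag': flag})[0]
out_b = sess.run(['if_output'], {'input_1': x1,
                                 'input_2': np.random.rand(300, 4096).astype(np.float32),
                                 'input_flag': flag})[0]
print(np.allclose(out_a, out_b))  # expected: True if only the selected branch contributes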
Environment
TensorRT Version: 10.2.0
GPU Type: A30
Nvidia Driver Version: 470.57.02
CUDA Version: 11.4
CUDNN Version:
Operating System + Version: CentOS 7
Python Version (if applicable):
TensorFlow Version (if applicable):
PyTorch Version (if applicable):
Baremetal or Container (if container which image + tag):
Relevant Files
Steps To Reproduce
Generate the ONNX model two_matmul_if.onnx:
python3 modify_onnx_simp.py
Build two_matmul_if.engine with trtexec:
trtexec --builderOptimizationLevel=3 --maxAuxStreams=0 --useCudaGraph --onnx=two_matmul_if.onnx --saveEngine=two_matmul_if.engine
Generate the flag inputs (in Python):
import numpy as np
np.array([0], dtype=np.int32).reshape((1, )).tofile('flag_0.bin')
np.array([1], dtype=np.int32).reshape((1, )).tofile('flag_1.bin')
Profile the engine with flag_0.bin; expect only the true-branch (MatMul_1) to run:
trtexec --warmUp=100 --iterations=10000 --useCudaGraph --separateProfileRun --dumpProfile --useSpinWait --noDataTransfers --loadEngine=two_matmul_if.engine --loadInputs='input_flag':flag_0.bin
Profile the engine with flag_1.bin; expect only the false-branch (MatMul_2) to run:
trtexec --warmUp=100 --iterations=10000 --useCudaGraph --separateProfileRun --dumpProfile --useSpinWait --noDataTransfers --loadEngine=two_matmul_if.engine --loadInputs='input_flag':flag_1.bin
The per-layer profiles for flag_0.bin and flag_1.bin are identical: both MatMul_1 and MatMul_2 are executed.
[07/15/2024-20:17:13] [I] Time(ms) Avg.(ms) Median(ms) Time(%) Layer
[07/15/2024-20:17:13] [I] 23.00 0.0023 0.0022 0.3 __mye186_hc_init_myl0_0
[07/15/2024-20:17:13] [I] 1313.88 0.1301 0.1297 18.1 MatMul_1_myl0_1
[07/15/2024-20:17:13] [I] 5517.65 0.5462 0.5486 76.1 MatMul_2_myl0_2
[07/15/2024-20:17:13] [I] 140.25 0.0139 0.0129 1.9 copy_d2h___mye161_myl0_3
[07/15/2024-20:17:13] [I] 23.69 0.0023 0.0023 0.3 __mye33cbr_myl0_4
[07/15/2024-20:17:13] [I] 211.70 0.0210 0.0208 2.9 __myl_Mov_myl0_5
[07/15/2024-20:17:13] [I] 22.79 0.0023 0.0022 0.3 jmp__mye39_myl0_6
[07/15/2024-20:17:13] [I] 7252.96 0.7180 0.7177 100.0 Total
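For a mechanical comparison of the two runs, --exportProfile=<file>.json can be added to each trtexec profiling command and the resulting layer lists diffed. A minimal sketch, assuming the hypothetical file names profile_flag_0.json / profile_flag_1.json and the usual trtexec profile layout of a JSON list of per-layer records with 'name' and 'averageMs' fields (key names may vary by trtexec version):
import json

def layer_times(path):
    # Keep only records that describe a layer; skip any header/summary entries.
    with open(path) as f:
        records = json.load(f)
    return {r['name']: r.get('averageMs') for r in records if isinstance(r, dict) and 'name' in r}

prof_0 = layer_times('profile_flag_0.json')  # run with flag_0.bin
prof_1 = layer_times('profile_flag_1.json')  # run with flag_1.bin
print('only in flag_0 run:', sorted(prof_0.keys() - prof_1.keys()))
print('only in flag_1 run:', sorted(prof_1.keys() - prof_0.keys()))
for name in sorted(prof_0.keys() & prof_1.keys()):
    print(name, prof_0[name], prof_1[name])
If only the selected branch executed in each run, the MatMul layers would differ between the two lists; in the observed behavior both MatMul_1 and MatMul_2 appear in both runs with essentially the same timings.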