Description
Hi.
I am working on converting a PyTorch MobileViT model to a TensorRT inference engine.
I export the PyTorch model to ONNX and then run onnx-simplifier; both steps succeed.
However, when I try to convert the simplified ONNX model to TensorRT, I get an error.
The error is related to the multi-head attention block of the MobileViT model.
The error message I get:
root/gpgpu/MachineLearning/myelin/src/compiler/optimizer/kqv_gemm_split.cpp:350: void myelin::ir::kqv_split_pattern_t::check_transpose(): Assertion `in_dims.size() == 3' failed.
I tested the conversion steps above on several GPUs:
Titan Xp, A100, and Titan RTX.
All of my attempts failed on these GPUs.
I used the NGC PyTorch image 22.03-py3:
image: nvcr.io/nvidia/pytorch:22.03-py3
Why does this happen, and how can I solve it?
My code is posted below.
Entire log message
[04/26/2022-21:37:56] [TRT] [I] [MemUsageChange] Init CUDA: CPU +426, GPU +0, now: CPU 827, GPU 3325 (MiB)
[04/26/2022-21:37:56] [TRT] [I] [MemUsageSnapshot] Begin constructing builder kernel library: CPU 827 MiB, GPU 3325 MiB
[04/26/2022-21:37:56] [TRT] [I] [MemUsageSnapshot] End constructing builder kernel library: CPU 1044 MiB, GPU 3397 MiB
[04/26/2022-21:37:56] [TRT] [I] ----------------------------------------------------------------
[04/26/2022-21:37:56] [TRT] [I] Input filename: mvit.onnx
[04/26/2022-21:37:56] [TRT] [I] ONNX IR version: 0.0.7
[04/26/2022-21:37:56] [TRT] [I] Opset version: 13
[04/26/2022-21:37:56] [TRT] [I] Producer name: pytorch
[04/26/2022-21:37:56] [TRT] [I] Producer version: 1.11.0
[04/26/2022-21:37:56] [TRT] [I] Domain:
[04/26/2022-21:37:56] [TRT] [I] Model version: 0
[04/26/2022-21:37:56] [TRT] [I] Doc string:
[04/26/2022-21:37:56] [TRT] [I] ----------------------------------------------------------------
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::BatchTilePlugin_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::BatchedNMS_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::BatchedNMSDynamic_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::CoordConvAC version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::CropAndResize version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::CropAndResizeDynamic version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::DetectionLayer_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::EfficientNMS_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::EfficientNMS_ONNX_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::EfficientNMS_TFTRT_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::FlattenConcat_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::GenerateDetection_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::GridAnchor_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::GridAnchorRect_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::InstanceNormalization_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::LReLU_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::MultilevelCropAndResize_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::MultilevelProposeROI_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::NMS_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::NMSDynamic_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::Normalize_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::PriorBox_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::ProposalLayer_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::Proposal version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::ProposalDynamic version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::PyramidROIAlign_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::Region_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::Reorg_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::ResizeNearest_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::RPROI_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::ScatterND version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::SpecialSlice_TRT version 1
[04/26/2022-21:37:56] [TRT] [V] Registered plugin creator - ::Split version 1
[04/26/2022-21:37:56] [TRT] [V] Adding network input: input with dtype: float32, dimensions: (1, 8, 1024, 120)
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: input for ONNX tensor: input
[04/26/2022-21:37:56] [TRT] [V] Importing initializer: to_out.0.bias
[04/26/2022-21:37:56] [TRT] [V] Importing initializer: 81
[04/26/2022-21:37:56] [TRT] [V] Importing initializer: 87
[04/26/2022-21:37:56] [TRT] [W] parsers/onnx/onnx2trt_utils.cpp:364: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[04/26/2022-21:37:56] [TRT] [V] Importing initializer: 104
[04/26/2022-21:37:56] [TRT] [V] Importing initializer: 105
[04/26/2022-21:37:56] [TRT] [V] Importing initializer: 6
[04/26/2022-21:37:56] [TRT] [V] Importing initializer: 59
[04/26/2022-21:37:56] [TRT] [V] Parsing node: MatMul_0 [MatMul]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: input
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 81
[04/26/2022-21:37:56] [TRT] [V] MatMul_0 [MatMul] inputs: [input -> (1, 8, 1024, 120)[FLOAT]], [81 -> (120, 96)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: 81 for ONNX node: 81
[04/26/2022-21:37:56] [TRT] [V] Registering layer: MatMul_0 for ONNX node: MatMul_0
[04/26/2022-21:37:56] [TRT] [I] MatMul_0: broadcasting input1 to make tensors conform, dims(input0)=[1,8,1024,120][NONE] dims(input1)=[1,1,120,96][NONE].
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: tensor for ONNX tensor: tensor
[04/26/2022-21:37:56] [TRT] [V] MatMul_0 [MatMul] outputs: [tensor -> (1, 8, 1024, 96)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: Split_2 [Split]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: tensor
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 6
[04/26/2022-21:37:56] [TRT] [V] Split_2 [Split] inputs: [tensor -> (1, 8, 1024, 96)[FLOAT]], [6 -> (3)[INT32]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Split_2 for ONNX node: Split_2
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Split_2_0 for ONNX node: Split_2
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Split_2_1 for ONNX node: Split_2
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: q for ONNX tensor: q
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: k for ONNX tensor: k
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: v for ONNX tensor: v
[04/26/2022-21:37:56] [TRT] [V] Split_2 [Split] outputs: [q -> (1, 8, 1024, 32)[FLOAT]], [k -> (1, 8, 1024, 32)[FLOAT]], [v -> (1, 8, 1024, 32)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: Reshape_3 [Reshape]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: q
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 87
[04/26/2022-21:37:56] [TRT] [V] Reshape_3 [Reshape] inputs: [q -> (1, 8, 1024, 32)[FLOAT]], [87 -> (5)[INT32]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Reshape_3 for ONNX node: Reshape_3
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: 26 for ONNX tensor: 26
[04/26/2022-21:37:56] [TRT] [V] Reshape_3 [Reshape] outputs: [26 -> (1, 8, 1024, 4, 8)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: Transpose_4 [Transpose]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 26
[04/26/2022-21:37:56] [TRT] [V] Transpose_4 [Transpose] inputs: [26 -> (1, 8, 1024, 4, 8)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Transpose_4 for ONNX node: Transpose_4
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: 27 for ONNX tensor: 27
[04/26/2022-21:37:56] [TRT] [V] Transpose_4 [Transpose] outputs: [27 -> (1, 8, 4, 1024, 8)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: Reshape_5 [Reshape]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: k
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 87
[04/26/2022-21:37:56] [TRT] [V] Reshape_5 [Reshape] inputs: [k -> (1, 8, 1024, 32)[FLOAT]], [87 -> (5)[INT32]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Reshape_5 for ONNX node: Reshape_5
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: 41 for ONNX tensor: 41
[04/26/2022-21:37:56] [TRT] [V] Reshape_5 [Reshape] outputs: [41 -> (1, 8, 1024, 4, 8)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: Reshape_6 [Reshape]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: v
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 87
[04/26/2022-21:37:56] [TRT] [V] Reshape_6 [Reshape] inputs: [v -> (1, 8, 1024, 32)[FLOAT]], [87 -> (5)[INT32]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Reshape_6 for ONNX node: Reshape_6
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: 55 for ONNX tensor: 55
[04/26/2022-21:37:56] [TRT] [V] Reshape_6 [Reshape] outputs: [55 -> (1, 8, 1024, 4, 8)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: Transpose_7 [Transpose]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 55
[04/26/2022-21:37:56] [TRT] [V] Transpose_7 [Transpose] inputs: [55 -> (1, 8, 1024, 4, 8)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Transpose_7 for ONNX node: Transpose_7
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: 56 for ONNX tensor: 56
[04/26/2022-21:37:56] [TRT] [V] Transpose_7 [Transpose] outputs: [56 -> (1, 8, 4, 1024, 8)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: Transpose_8 [Transpose]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 41
[04/26/2022-21:37:56] [TRT] [V] Transpose_8 [Transpose] inputs: [41 -> (1, 8, 1024, 4, 8)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Transpose_8 for ONNX node: Transpose_8
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: 57 for ONNX tensor: 57
[04/26/2022-21:37:56] [TRT] [V] Transpose_8 [Transpose] outputs: [57 -> (1, 8, 4, 8, 1024)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: MatMul_9 [MatMul]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 27
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 57
[04/26/2022-21:37:56] [TRT] [V] MatMul_9 [MatMul] inputs: [27 -> (1, 8, 4, 1024, 8)[FLOAT]], [57 -> (1, 8, 4, 8, 1024)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: MatMul_9 for ONNX node: MatMul_9
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: 58 for ONNX tensor: 58
[04/26/2022-21:37:56] [TRT] [V] MatMul_9 [MatMul] outputs: [58 -> (1, 8, 4, 1024, 1024)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: Mul_11 [Mul]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 58
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 59
[04/26/2022-21:37:56] [TRT] [V] Mul_11 [Mul] inputs: [58 -> (1, 8, 4, 1024, 1024)[FLOAT]], [59 -> ()[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: 59 for ONNX node: 59
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Mul_11 for ONNX node: Mul_11
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: input.1 for ONNX tensor: input.1
[04/26/2022-21:37:56] [TRT] [V] Mul_11 [Mul] outputs: [input.1 -> (1, 8, 4, 1024, 1024)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: Softmax_12 [Softmax]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: input.1
[04/26/2022-21:37:56] [TRT] [V] Softmax_12 [Softmax] inputs: [input.1 -> (1, 8, 4, 1024, 1024)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Softmax_12 for ONNX node: Softmax_12
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: 61 for ONNX tensor: 61
[04/26/2022-21:37:56] [TRT] [V] Softmax_12 [Softmax] outputs: [61 -> (1, 8, 4, 1024, 1024)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: MatMul_13 [MatMul]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 61
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 56
[04/26/2022-21:37:56] [TRT] [V] MatMul_13 [MatMul] inputs: [61 -> (1, 8, 4, 1024, 1024)[FLOAT]], [56 -> (1, 8, 4, 1024, 8)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: MatMul_13 for ONNX node: MatMul_13
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: out for ONNX tensor: out
[04/26/2022-21:37:56] [TRT] [V] MatMul_13 [MatMul] outputs: [out -> (1, 8, 4, 1024, 8)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: Transpose_14 [Transpose]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: out
[04/26/2022-21:37:56] [TRT] [V] Transpose_14 [Transpose] inputs: [out -> (1, 8, 4, 1024, 8)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Transpose_14 for ONNX node: Transpose_14
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: 66 for ONNX tensor: 66
[04/26/2022-21:37:56] [TRT] [V] Transpose_14 [Transpose] outputs: [66 -> (1, 8, 1024, 4, 8)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: Reshape_15 [Reshape]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 66
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 104
[04/26/2022-21:37:56] [TRT] [V] Reshape_15 [Reshape] inputs: [66 -> (1, 8, 1024, 4, 8)[FLOAT]], [104 -> (4)[INT32]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Reshape_15 for ONNX node: Reshape_15
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: input.4 for ONNX tensor: input.4
[04/26/2022-21:37:56] [TRT] [V] Reshape_15 [Reshape] outputs: [input.4 -> (1, 8, 1024, 32)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: MatMul_16 [MatMul]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: input.4
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 105
[04/26/2022-21:37:56] [TRT] [V] MatMul_16 [MatMul] inputs: [input.4 -> (1, 8, 1024, 32)[FLOAT]], [105 -> (32, 120)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: 105 for ONNX node: 105
[04/26/2022-21:37:56] [TRT] [V] Registering layer: MatMul_16 for ONNX node: MatMul_16
[04/26/2022-21:37:56] [TRT] [I] MatMul_16: broadcasting input1 to make tensors conform, dims(input0)=[1,8,1024,32][NONE] dims(input1)=[1,1,32,120][NONE].
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: 79 for ONNX tensor: 79
[04/26/2022-21:37:56] [TRT] [V] MatMul_16 [MatMul] outputs: [79 -> (1, 8, 1024, 120)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Parsing node: Add_17 [Add]
[04/26/2022-21:37:56] [TRT] [V] Searching for input: to_out.0.bias
[04/26/2022-21:37:56] [TRT] [V] Searching for input: 79
[04/26/2022-21:37:56] [TRT] [V] Add_17 [Add] inputs: [to_out.0.bias -> (120)[FLOAT]], [79 -> (1, 8, 1024, 120)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Registering layer: to_out.0.bias for ONNX node: to_out.0.bias
[04/26/2022-21:37:56] [TRT] [V] Registering layer: Add_17 for ONNX node: Add_17
[04/26/2022-21:37:56] [TRT] [V] Registering tensor: output_2 for ONNX tensor: output
[04/26/2022-21:37:56] [TRT] [V] Add_17 [Add] outputs: [output -> (1, 8, 1024, 120)[FLOAT]],
[04/26/2022-21:37:56] [TRT] [V] Marking output_2 as output: output
/home/mvit/trt_convert.py:86: DeprecationWarning: Use build_serialized_network instead.
trt_engine = builder.build_engine(network, config)
[04/26/2022-21:37:56] [TRT] [I] MatMul_0: broadcasting input1 to make tensors conform, dims(input0)=[1,8,1024,120][NONE] dims(input1)=[1,1,120,96][NONE].
[04/26/2022-21:37:56] [TRT] [I] MatMul_16: broadcasting input1 to make tensors conform, dims(input0)=[1,8,1024,32][NONE] dims(input1)=[1,1,32,120][NONE].
[04/26/2022-21:37:56] [TRT] [V] Applying generic optimizations to the graph for inference.
[04/26/2022-21:37:56] [TRT] [V] Original: 27 layers
[04/26/2022-21:37:56] [TRT] [V] After dead-layer removal: 27 layers
[04/26/2022-21:37:56] [TRT] [V] Running: ConstShuffleFusion
[04/26/2022-21:37:56] [TRT] [V] ConstShuffleFusion: Fusing 81 with (Unnamed Layer* 1) [Shuffle]
[04/26/2022-21:37:56] [TRT] [V] Running: ShuffleShuffleFusion
[04/26/2022-21:37:56] [TRT] [V] ShuffleShuffleFusion: Fusing Reshape_3 with Transpose_4
[04/26/2022-21:37:56] [TRT] [V] Running: ShuffleShuffleFusion
[04/26/2022-21:37:56] [TRT] [V] ShuffleShuffleFusion: Fusing Reshape_5 with Transpose_8
[04/26/2022-21:37:56] [TRT] [V] Running: ShuffleShuffleFusion
[04/26/2022-21:37:56] [TRT] [V] ShuffleShuffleFusion: Fusing Reshape_6 with Transpose_7
[04/26/2022-21:37:56] [TRT] [V] Running: ConstShuffleFusion
[04/26/2022-21:37:56] [TRT] [V] ConstShuffleFusion: Fusing 59 with (Unnamed Layer* 14) [Shuffle]
[04/26/2022-21:37:56] [TRT] [V] Running: ShuffleErasure
[04/26/2022-21:37:56] [TRT] [V] Removing (Unnamed Layer* 17) [Shuffle]
[04/26/2022-21:37:56] [TRT] [V] Running: ShuffleShuffleFusion
[04/26/2022-21:37:56] [TRT] [V] ShuffleShuffleFusion: Fusing Transpose_14 with Reshape_15
[04/26/2022-21:37:56] [TRT] [V] Running: ConstShuffleFusion
[04/26/2022-21:37:56] [TRT] [V] ConstShuffleFusion: Fusing 105 with (Unnamed Layer* 22) [Shuffle]
[04/26/2022-21:37:56] [TRT] [V] Running: ConstShuffleFusion
[04/26/2022-21:37:56] [TRT] [V] ConstShuffleFusion: Fusing to_out.0.bias with (Unnamed Layer* 25) [Shuffle]
[04/26/2022-21:37:56] [TRT] [V] Found Split_2 to be part of self-attention pattern.
[04/26/2022-21:37:56] [TRT] [V] Found Split_2_0 to be part of self-attention pattern.
[04/26/2022-21:37:56] [TRT] [V] Found Split_2_1 to be part of self-attention pattern.
[04/26/2022-21:37:56] [TRT] [V] Found MatMul_9 to be part of self-attention pattern.
[04/26/2022-21:37:56] [TRT] [V] Found Softmax_12 to be part of self-attention pattern.
[04/26/2022-21:37:56] [TRT] [V] Found MatMul_13 to be part of self-attention pattern.
[04/26/2022-21:37:56] [TRT] [V] Found MatMul_0 to be part of self-attention pattern.
[04/26/2022-21:37:56] [TRT] [V] Found and reassigned Myelin backends for Self-Attention nodes
[04/26/2022-21:37:56] [TRT] [V] After Myelin optimization: 1 layers
[04/26/2022-21:37:56] [TRT] [V] Applying ScaleNodes fusions.
[04/26/2022-21:37:56] [TRT] [V] After scale fusion: 1 layers
[04/26/2022-21:37:56] [TRT] [V] After vertical fusions: 1 layers
[04/26/2022-21:37:56] [TRT] [V] After dupe layer removal: 1 layers
[04/26/2022-21:37:56] [TRT] [V] After final dead-layer removal: 1 layers
[04/26/2022-21:37:56] [TRT] [V] After tensor merging: 1 layers
[04/26/2022-21:37:56] [TRT] [V] After concat removal: 1 layers
[04/26/2022-21:37:56] [TRT] [V] Graph construction and optimization completed in 0.0030201 seconds.
[04/26/2022-21:37:57] [TRT] [V] Using cublasLt as a tactic source
[04/26/2022-21:37:57] [TRT] [I] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +809, GPU +350, now: CPU 1853, GPU 3747 (MiB)
[04/26/2022-21:37:57] [TRT] [V] Using cuDNN as a tactic source
[04/26/2022-21:37:57] [TRT] [I] [MemUsageChange] Init cuDNN: CPU +126, GPU +58, now: CPU 1979, GPU 3805 (MiB)
[04/26/2022-21:37:57] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[04/26/2022-21:37:57] [TRT] [V] Constructing optimization profile number 0 [1/1].
[04/26/2022-21:37:57] [TRT] [V] Reserving memory for activation tensors. Host: 0 bytes Device: 7864320 bytes
[04/26/2022-21:37:57] [TRT] [V] =============== Computing reformatting costs
[04/26/2022-21:37:57] [TRT] [V] *************** Autotuning Reformat: Float(983040,122880,120,1) -> Half(983040,122880,120,1) ***************
[04/26/2022-21:37:57] [TRT] [V] --------------- Timing Runner: Optimizer Reformat(input -> <out>) (Reformat)
[04/26/2022-21:37:57] [TRT] [V] Tactic: 1002 Time: 0.024576
[04/26/2022-21:37:57] [TRT] [V] Tactic: 0 Time: 0.026624
[04/26/2022-21:37:57] [TRT] [V] Fastest Tactic: 1002 Time: 0.024576
[04/26/2022-21:37:57] [TRT] [V] *************** Autotuning Reformat: Float(983040,122880,120,1) -> Half(122880,1:8,120,1) ***************
[04/26/2022-21:37:57] [TRT] [V] --------------- Timing Runner: Optimizer Reformat(input -> <out>) (Reformat)
[04/26/2022-21:37:57] [TRT] [V] Tactic: 1002 Time: 0.02048
[04/26/2022-21:37:57] [TRT] [V] Tactic: 0 Time: 0.016384
[04/26/2022-21:37:57] [TRT] [V] Fastest Tactic: 0 Time: 0.016384
[04/26/2022-21:37:57] [TRT] [V] =============== Computing reformatting costs
[04/26/2022-21:37:57] [TRT] [V] *************** Autotuning Reformat: Half(983040,122880,120,1) -> Float(983040,122880,120,1) ***************
[04/26/2022-21:37:57] [TRT] [V] --------------- Timing Runner: Optimizer Reformat(<in> -> output) (Reformat)
[04/26/2022-21:37:57] [TRT] [V] Tactic: 1002 Time: 0.024576
[04/26/2022-21:37:57] [TRT] [V] Tactic: 0 Time: 0.024576
[04/26/2022-21:37:57] [TRT] [V] Fastest Tactic: 1002 Time: 0.024576
[04/26/2022-21:37:57] [TRT] [V] *************** Autotuning Reformat: Half(122880,1:8,120,1) -> Float(983040,122880,120,1) ***************
[04/26/2022-21:37:57] [TRT] [V] --------------- Timing Runner: Optimizer Reformat(<in> -> output) (Reformat)
[04/26/2022-21:37:57] [TRT] [V] Tactic: 1002 Time: 0.048128
[04/26/2022-21:37:57] [TRT] [V] Tactic: 0 Time: 0.014336
[04/26/2022-21:37:57] [TRT] [V] Fastest Tactic: 0 Time: 0.014336
[04/26/2022-21:37:57] [TRT] [V] =============== Computing costs for
[04/26/2022-21:37:57] [TRT] [V] *************** Autotuning format combination: Float(983040,122880,120,1) -> Float(983040,122880,120,1) ***************
[04/26/2022-21:37:57] [TRT] [V] --------------- Timing Runner: {ForeignNode[81 + (Unnamed Layer* 1) [Shuffle]...Add_17]} (Myelin)
python: /root/gpgpu/MachineLearning/myelin/src/compiler/optimizer/kqv_gemm_split.cpp:350: void myelin::ir::kqv_split_pattern_t::check_transpose(): Assertion `in_dims.size() == 3' failed.
Environment
TensorRT Version: 8.2.2.1
GPU Type: A100
Nvidia Driver Version: 470.57.02
CUDA Version: 11.6
CUDNN Version: 8.3.2
Operating System + Version: Ubuntu 20.04.3 LTS (Focal Fossa)
Python Version (if applicable): 3.8.12
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.11.0a0+bfe5ad2
Baremetal or Container (if container which image + tag):
Relevant Files
import torch
from torch import nn
import onnx
from onnxsim import simplify
import tensorrt as trt
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x)
        q, k, v = torch.split(qkv, qkv.shape[3] // 3, dim=3)
        bs, ch, n, h = q.shape
        q = self.qkv2MultiheadForm(q, bs, ch, n, self.heads, h // self.heads)
        k = self.qkv2MultiheadForm(k, bs, ch, n, self.heads, h // self.heads)
        v = self.qkv2MultiheadForm(v, bs, ch, n, self.heads, h // self.heads)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        b, p, h, n, d = out.shape
        out = self.multiheadForm2qkv(out, b, p, h, n, d)
        return self.to_out(out)

    def qkv2MultiheadForm(self, x, bs: int, ch: int, n: int, tmp_h: int, heads: int):
        x = x.reshape(bs, ch, n, tmp_h, heads).permute(0, 1, 3, 2, 4)
        return x

    def multiheadForm2qkv(self, x, b: int, p: int, h: int, n: int, d: int):
        return x.permute(0, 1, 3, 2, 4).reshape(b, p, n, h * d)
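For reference, the shape algebra of qkv2MultiheadForm can be mirrored in plain NumPy (a minimal sketch, not part of the original script). It shows that q/k/v become rank-5 tensors, matching the Reshape_3/Transpose_4 output in the verbose log — which is presumably what trips Myelin's KQV-split pass, whose check_transpose asserts in_dims.size() == 3, i.e. rank-3 inputs:

```python
import numpy as np

# One of the q/k/v split outputs has shape (1, 8, 1024, 32), as in the log.
q = np.zeros((1, 8, 1024, 32), dtype=np.float32)

# qkv2MultiheadForm: reshape to (bs, ch, n, heads, head_dim), then permute,
# giving the rank-5 tensor Reshape_3/Transpose_4 produce in the log.
q5 = q.reshape(1, 8, 1024, 4, 8).transpose(0, 1, 3, 2, 4)
print(q5.shape)  # (1, 8, 4, 1024, 8) -- rank 5, not the rank 3 Myelin asserts
```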
def torch2onnx(net, input, onnx_file, opver=13, do_simplify=True):
    torch.onnx.export(net, input, onnx_file,
                      input_names=['input'],
                      output_names=['output'],
                      opset_version=opver,
                      do_constant_folding=True
                      )
    if do_simplify:
        # load the exported ONNX model
        sim_model = onnx.load(onnx_file)
        # simplify the model graph
        model_simp, check = simplify(sim_model)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, onnx_file)
def torch2tensorrt(onnx_file, trt_file):
    logger = trt.Logger(trt.Logger.VERBOSE)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    with trt.OnnxParser(network, logger) as parser:
        success = parser.parse_from_file(onnx_file)
        for idx in range(parser.num_errors):
            print(parser.get_error(idx))
        if not success:
            print('parse onnx file failed.')
            return
    config = builder.create_builder_config()
    config.max_workspace_size = int(4 << 30)
    if builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    # build_engine is deprecated in TensorRT 8.x (hence the DeprecationWarning
    # in the log); build_serialized_network is the suggested replacement
    trt_engine = builder.build_engine(network, config)
    serialized_engine = trt_engine.serialize()
    with open(trt_file, "wb") as f:
        f.write(serialized_engine)
    return trt_engine
def torch_dtype_from_trt(dtype):
    if dtype == trt.int8:
        return torch.int8
    elif dtype == trt.bool:
        return torch.bool
    elif dtype == trt.int32:
        return torch.int32
    elif dtype == trt.float16:
        return torch.float16
    elif dtype == trt.float32:
        return torch.float32
    else:
        raise TypeError("%s is not supported by torch" % dtype)

def torch_device_to_trt(device):
    if device.type == torch.device("cuda").type:
        return trt.TensorLocation.DEVICE
    elif device.type == torch.device("cpu").type:
        return trt.TensorLocation.HOST
    else:
        raise TypeError("%s is not supported by tensorrt" % device)

def torch_device_from_trt(device):
    if device == trt.TensorLocation.DEVICE:
        return torch.device("cuda")
    elif device == trt.TensorLocation.HOST:
        return torch.device("cpu")
    else:
        raise TypeError("%s is not supported by torch" % device)
def load_trt_model(trt_file):
    logger = trt.Logger(trt.Logger.VERBOSE)
    runtime = trt.Runtime(logger)
    with open(trt_file, "rb") as f:
        serialized_engine = f.read()
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    return engine
def bench_mark(engine, img):
    context = engine.create_execution_context()
    input_idx = engine['input']
    output_idx = engine['output']
    buffers = [None] * 2  # assuming 1 input and 1 output
    input_ptr = img.contiguous().data_ptr()
    shape = tuple(img.shape)
    context.set_binding_shape(input_idx, shape)
    dtype = torch_dtype_from_trt(engine.get_binding_dtype(output_idx))
    device = torch_device_from_trt(engine.get_location(output_idx))
    output = torch.empty(size=shape, dtype=dtype, device=device)
    buffers[input_idx] = input_ptr
    buffers[output_idx] = output.data_ptr()
    context.execute_async(1, buffers, torch.cuda.current_stream().cuda_stream)
    return output
def get_tensorrt_engine(net, img, onnx_file, trt_file):
    torch2onnx(net, img, onnx_file)
    engine = torch2tensorrt(onnx_file, trt_file)
    return engine

if __name__ == "__main__":
    onnx_file = 'mvit.onnx'
    trt_file = 'mvit.plan'
    net = Attention(120, 4, 8)  # dim=120, heads=4, dim_head=8
    net.eval()
    img = torch.randn(1, 8, 1024, 120).float()
    out = net(img)
    engine = get_tensorrt_engine(net, img, onnx_file, trt_file)
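One possible direction for a workaround (a sketch I have only checked mathematically, not against TensorRT): attention over independent leading dimensions is unchanged when those dimensions are folded into a single batch dimension, so restructuring the module to run attention on rank-3 tensors (merging the leading dims before attention and splitting them back after) should be numerically equivalent while presenting Myelin with the rank-3 pattern its assertion expects. A NumPy check of that equivalence:

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x):
    # numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v, scale):
    # q, k, v: (..., n, d); every leading dim is an independent batch dim
    return softmax(q @ np.swapaxes(k, -1, -2) * scale) @ v

# rank-5 layout as in the exported model: (batch, patches, heads, n, head_dim)
q, k, v = (rng.standard_normal((1, 8, 4, 16, 8)) for _ in range(3))
out5 = attention(q, k, v, 8 ** -0.5)

# same data with all leading dims folded into one batch dim (rank 3)
out3 = attention(q.reshape(-1, 16, 8), k.reshape(-1, 16, 8),
                 v.reshape(-1, 16, 8), 8 ** -0.5)

print(np.allclose(out5, out3.reshape(out5.shape)))  # True
```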
Steps To Reproduce
1. Start the NGC container nvcr.io/nvidia/pytorch:22.03-py3.
2. Save the script above as trt_convert.py and run "python trt_convert.py" to export mvit.onnx, simplify it, and build the TensorRT engine.
3. The engine build fails with the assertion shown in the full log above.