MatrixMultiply failed on TensorRT 7.2.1

Description

Hi,
I am trying to update my code from TRT 7.1 to TRT 7.2.
I found that add_matrix_multiply gives the right answer on TensorRT 7.1.3.4, but fails on TensorRT 7.2.1.6 with this error log:

[TensorRT] INTERNAL ERROR: Assertion failed: cublasStatus == CUBLAS_STATUS_SUCCESS
../rtSafe/cublas/cublasLtWrapper.cpp:279

Environment

TensorRT Version: 7.1.3.4/ 7.2.1.6
GPU Type: 2070 Super
Nvidia Driver Version: 450.80.02
CUDA Version: 10.2
CUDNN Version: 8.0.4
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): 3.7
TensorFlow Version (if applicable):
PyTorch Version (if applicable): 1.7.0
Baremetal or Container (if container which image + tag):

Steps To Reproduce

import tensorrt as trt
import torch
import numpy as np


def main():

    print("create trt model")
    log_level = trt.Logger.ERROR
    logger = trt.Logger(log_level)
    builder = trt.Builder(logger)

    ## build network
    EXPLICIT_BATCH = 1 << (int)(
        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(EXPLICIT_BATCH)
    input1_name = 'input1'
    input2_name = 'input2'
    output_name = 'output'
    input1_trt = network.add_input(name=input1_name,
                                   shape=[1, 4100, 25],
                                   dtype=trt.float32)
    input2_trt = network.add_input(name=input2_name,
                                   shape=[1, 25, 4100],
                                   dtype=trt.float32)

    # matmul
    matmul_trt = network.add_matrix_multiply(
        input1_trt, trt.MatrixOperation.NONE, input2_trt,
        trt.MatrixOperation.NONE).get_output(0)

    output = matmul_trt
    output.name = output_name
    network.mark_output(output)

    ## builder config
    max_workspace_size = 1 << 30
    fp16_mode = False

    builder.max_workspace_size = max_workspace_size
    builder.fp16_mode = fp16_mode

    config = builder.create_builder_config()
    config.max_workspace_size = max_workspace_size
    profile = builder.create_optimization_profile()

    # set shape
    input1_shape = (1, 4100, 25)
    profile.set_shape(input1_name, input1_shape, input1_shape, input1_shape)
    input2_shape = (1, 25, 4100)
    profile.set_shape(input2_name, input2_shape, input2_shape, input2_shape)
    config.add_optimization_profile(profile)
    if fp16_mode:
        config.set_flag(trt.BuilderFlag.FP16)

    # build engine
    engine = builder.build_engine(network, config)
    context = engine.create_execution_context()

    print("inference")
    input1_torch = torch.rand(1, 4100, 25).cuda().contiguous()
    input2_torch = torch.rand(1, 25, 4100).cuda().contiguous()

    bindings = [None] * 3

    # set input
    idx = engine.get_binding_index(input1_name)
    context.set_binding_shape(idx, tuple(input1_torch.shape))
    bindings[idx] = input1_torch.data_ptr()
    idx = engine.get_binding_index(input2_name)
    context.set_binding_shape(idx, tuple(input2_torch.shape))
    bindings[idx] = input2_torch.data_ptr()

    # set output
    idx = engine.get_binding_index(output_name)
    shape = tuple(context.get_binding_shape(idx))
    output_torch = torch.empty(shape, dtype=torch.float32).cuda()
    bindings[idx] = output_torch.data_ptr()

    context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)

    print(output_torch.shape)
    print(output_torch.view(-1)[:10])


if __name__ == "__main__":
    main()

On TensorRT 7.1.3.4, the log is:

create trt model
inference
torch.Size([1, 4100, 4100])
tensor([5.5353, 4.6237, 5.2224, 7.7332, 6.3585, 4.7470, 5.5884, 4.8122, 6.1835,
        5.2994], device='cuda:0')

On TensorRT 7.2.1.6, the log is:

create trt model
inference
[TensorRT] INTERNAL ERROR: Assertion failed: cublasStatus == CUBLAS_STATUS_SUCCESS
../rtSafe/cublas/cublasLtWrapper.cpp:279
Aborting...
[TensorRT] ERROR: FAILED_EXECUTION: std::exception
torch.Size([1, 4100, 4100])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')

Did I miss some configuration on 7.2.1.6? How do I fix this?
Thanks

Getting the same error here. Could someone help please?

Getting the same issue here, on TensorRT 7.2.1.

Solved by upgrading cuBLAS to 10.2.2.214 (the CUDA 10.2 patch).
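
If you want to confirm which cuBLAS version your process actually loads before and after patching, here is a hypothetical check using ctypes; it assumes libcublas.so.10 is on the loader path, as it is for a standard CUDA 10.2 install:

import ctypes

# Query the version of the cuBLAS library that gets loaded at runtime,
# to confirm the patched CUDA 10.2 cuBLAS is the one being picked up.
cublas = ctypes.CDLL("libcublas.so.10")
handle = ctypes.c_void_p()
version = ctypes.c_int()
assert cublas.cublasCreate_v2(ctypes.byref(handle)) == 0      # CUBLAS_STATUS_SUCCESS
assert cublas.cublasGetVersion_v2(handle, ctypes.byref(version)) == 0
print("loaded cuBLAS version:", version.value)
cublas.cublasDestroy_v2(handle)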

Hi,
TRT 7.2.1 switched to using cuBLASLt (previously it was cuBLAS); cuBLASLt is the default choice for SM version >= 7.0. However, you may need CUDA 10.2 Patch 1 (released Aug 26, 2020) to resolve some cuBLASLt issues. Another option, if you don't want to upgrade, is to use the new TacticSource API and disable the cuBLASLt tactics, as sketched below.
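
A minimal sketch of that second option, assuming the config object from the reproduction script above and TRT >= 7.2 (where IBuilderConfig.set_tactic_sources and the trt.TacticSource enum are available):

import tensorrt as trt

# Restrict the builder to plain cuBLAS tactics so cuBLASLt is never used.
# "config" is the IBuilderConfig created in the reproduction script above.
tactic_sources = 1 << int(trt.TacticSource.CUBLAS)  # keep cuBLAS, drop CUBLAS_LT
config.set_tactic_sources(tactic_sources)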

Thanks!
