Matrix Multiplication -> PointWise Operation is Always Read as an MHA Pattern

I am trying to implement a forward attention algorithm using the cuDNN backend API, version 9.1.1. As concluded in this discussion, I am unable to use the MHA pattern as it is only supported on Hopper GPUs. I have now changed my strategy to use two graphs (matmul->scale and softmax->matmul) to represent the attention operation.

I am currently implementing the first graph which I believe subscribes to the first graph pattern listed in this section. I have also made all tensors follow the NHWC layout as stated here.

When I run the following implementation though, I get 0 engine configurations returned by the heuristic and warning message: Warning: CUDNN_STATUS_NOT_SUPPORTED; Reason: mha patterns only support 2, 4 or 5 gemms at: number_of_gemms != 2 && number_of_gemms != 4 && number_of_gemms != 5. I am not sure why the graph is being read as an MHA graph and not a MatMul fusion graph.

What should I change to make the graph a proper MatMul fusion graph? Is it possible to combine a matrix multiplication and a pointwise operation in a non-MHA graph?

#include <cudnn.h>
#include <iostream>
#include <vector>

// Checks the result of a cuDNN call and aborts with a diagnostic (file,
// line, status string, and cuDNN's last error message) on failure.
//
// Fixes over the previous version:
//  - do { ... } while (0) makes the macro a single statement, so it is safe
//    inside an unbraced if/else (a bare { ... } followed by ';' is not).
//  - The argument is evaluated exactly once. It is always a cuDNN API call
//    here (e.g. CUDNN_CHECK(cudnnCreate(&handle))), and the old macro
//    re-evaluated it in cudnnGetErrorString(status) on failure, which would
//    have executed the API call a second time.
#define CUDNN_CHECK(status)                                                    \
    do {                                                                       \
        const cudnnStatus_t cudnn_check_status_ = (status);                    \
        if (cudnn_check_status_ != CUDNN_STATUS_SUCCESS) {                     \
            std::cerr << "cuDNN error at " << __FILE__ << ":" << __LINE__      \
                      << ": " << cudnnGetErrorString(cudnn_check_status_)      \
                      << std::endl;                                            \
            char error_message[256];                                           \
            cudnnGetLastErrorString(error_message, sizeof(error_message));     \
            std::cerr << "reason: " << error_message;                          \
            std::cerr << std::endl;                                            \
            std::exit(EXIT_FAILURE);                                           \
        }                                                                      \
    } while (0)

// Prints a labelled vector to stdout as "<name>: <size>: e0 e1 ... "
// (trailing space and newline), handy for dumping shapes/strides.
// `name` is now taken by const reference to avoid copying the string on
// every call (performance-unnecessary-value-param); callers are unaffected.
void print_vector(const std::vector<int64_t> &v, const std::string &name) {
    std::cout << name << ": " << v.size() << ": ";
    for (const int64_t i : v) {
        std::cout << i << " ";
    }
    std::cout << std::endl;
}

// Creates and finalizes a cuDNN backend tensor descriptor with the given
// dimensions, strides, unique id, data type, byte alignment, and
// virtual-ness (virtual tensors are graph intermediates that are never
// materialized in memory). The caller owns the returned descriptor.
cudnnBackendDescriptor_t tensor_descriptor(const std::vector<int64_t> &shape,
                                           const std::vector<int64_t> &strides,
                                           int64_t id,
                                           cudnnDataType_t data_type,
                                           int64_t byte_alignment,
                                           bool is_virtual) {
    cudnnBackendDescriptor_t tensor;
    CUDNN_CHECK(
        cudnnBackendCreateDescriptor(CUDNN_BACKEND_TENSOR_DESCRIPTOR, &tensor));

    const int64_t dim_count = static_cast<int64_t>(shape.size());
    const int64_t stride_count = static_cast<int64_t>(strides.size());

    // Each attribute is independent; all must be set before finalization.
    CUDNN_CHECK(cudnnBackendSetAttribute(tensor, CUDNN_ATTR_TENSOR_UNIQUE_ID,
                                         CUDNN_TYPE_INT64, 1, &id));
    CUDNN_CHECK(cudnnBackendSetAttribute(tensor, CUDNN_ATTR_TENSOR_DATA_TYPE,
                                         CUDNN_TYPE_DATA_TYPE, 1, &data_type));
    CUDNN_CHECK(cudnnBackendSetAttribute(tensor,
                                         CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT,
                                         CUDNN_TYPE_INT64, 1, &byte_alignment));
    CUDNN_CHECK(cudnnBackendSetAttribute(tensor, CUDNN_ATTR_TENSOR_DIMENSIONS,
                                         CUDNN_TYPE_INT64, dim_count,
                                         shape.data()));
    CUDNN_CHECK(cudnnBackendSetAttribute(tensor, CUDNN_ATTR_TENSOR_STRIDES,
                                         CUDNN_TYPE_INT64, stride_count,
                                         strides.data()));
    CUDNN_CHECK(cudnnBackendSetAttribute(tensor, CUDNN_ATTR_TENSOR_IS_VIRTUAL,
                                         CUDNN_TYPE_BOOLEAN, 1, &is_virtual));

    CUDNN_CHECK(cudnnBackendFinalize(tensor));
    return tensor;
}

// Strides for an NHWC memory layout given logical dims {N, C, H, W}:
// channel is innermost (stride 1), then W, then H, then batch.
std::vector<int64_t> nhwc_strides(const std::vector<int64_t> &shape) {
    const int64_t c = shape[1];
    const int64_t h = shape[2];
    const int64_t w = shape[3];
    return {c * h * w, 1, c * w, c};
}

// Strides for a fully-packed NCHW (row-major) layout over dims {N, C, H, W}:
// each stride is the product of all dimensions to its right.
std::vector<int64_t> nchw_strides(const std::vector<int64_t> &shape) {
    std::vector<int64_t> strides(4, 1);
    for (int d = 2; d >= 0; --d) {
        strides[d] = strides[d + 1] * shape[d + 1];
    }
    return strides;
}

// Builds a two-op graph (scores = Q x K^T, scaled = scores * scaler) with
// the cuDNN backend API, queries heuristic mode A for engine configs, and
// attempts to finalize an execution plan for each returned config.
int main() {
    // Q: {batch=1, heads=4, seq_len=10, head_dim=64}, dense row-major
    // ("NCHW") strides.
    std::vector<int64_t> shape_query = {1, 4, 10, 64};
    std::vector<int64_t> stride_query = nchw_strides(shape_query);

    // Scalar multiplier, broadcast over the matmul output by the pointwise op.
    std::vector<int64_t> shape_scaler = {1, 1, 1, 1};
    std::vector<int64_t> stride_scaler = nhwc_strides(shape_scaler);

    // K^T is expressed as a strided view over K's dense {1, 4, 10, 64}
    // buffer: transposing the last two dimensions requires swapping BOTH the
    // last two shape entries and the last two stride entries.
    std::vector<int64_t> shape_key = {1, 4, 10, 64};
    std::vector<int64_t> stride_key = nchw_strides(shape_key);
    std::swap(shape_key[2], shape_key[3]);
    // BUG FIX: this previously swapped stride_key[1] and stride_key[2],
    // yielding strides {2560, 64, 640, 1} for shape {1, 4, 64, 10} — not a
    // transpose (it shuffles the head and row strides and addresses past the
    // end of the 2560-element buffer). Swapping strides 2 and 3 gives the
    // correct transposed view {2560, 640, 1, 64}.
    std::swap(stride_key[2], stride_key[3]);

    // S = Q x K^T has shape {1, 4, 10, 10}.
    std::vector<int64_t> shape_scores = {1, 4, 10, 10};
    std::vector<int64_t> stride_scores = nhwc_strides(shape_scores);

    cudnnHandle_t handle;
    CUDNN_CHECK(cudnnCreate(&handle));
    cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
    int64_t data_type_byte_alignment = 4;

    // Tensor UIDs only need to be unique within the graph; readable char
    // literals are used as int64 ids. 'o' (the raw matmul output) is virtual:
    // it is an intermediate consumed by the pointwise op and never written
    // to device memory.
    cudnnBackendDescriptor_t query =
        tensor_descriptor(shape_query, stride_query, 'q', data_type,
                          data_type_byte_alignment, false);
    cudnnBackendDescriptor_t key = tensor_descriptor(
        shape_key, stride_key, 'k', data_type, data_type_byte_alignment, false);
    cudnnBackendDescriptor_t scores =
        tensor_descriptor(shape_scores, stride_scores, 'o', data_type,
                          data_type_byte_alignment, true);
    cudnnBackendDescriptor_t scaler =
        tensor_descriptor(shape_scaler, stride_scaler, 'r', data_type,
                          data_type_byte_alignment, false);
    cudnnBackendDescriptor_t scaled =
        tensor_descriptor(shape_scores, stride_scores, 's', data_type,
                          data_type_byte_alignment, false);

    // Matmul operation: scores = query x key, batched over the two leading
    // dimensions, accumulating in FP32.
    cudnnBackendDescriptor_t matmul_desc;
    CUDNN_CHECK(cudnnBackendCreateDescriptor(CUDNN_BACKEND_MATMUL_DESCRIPTOR,
                                             &matmul_desc));
    CUDNN_CHECK(cudnnBackendSetAttribute(matmul_desc,
                                         CUDNN_ATTR_MATMUL_COMP_TYPE,
                                         CUDNN_TYPE_DATA_TYPE, 1, &data_type));
    CUDNN_CHECK(cudnnBackendFinalize(matmul_desc));
    cudnnBackendDescriptor_t op_matmul;
    CUDNN_CHECK(cudnnBackendCreateDescriptor(
        CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR, &op_matmul));
    CUDNN_CHECK(cudnnBackendSetAttribute(
        op_matmul, CUDNN_ATTR_OPERATION_MATMUL_DESC,
        CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &matmul_desc));
    CUDNN_CHECK(
        cudnnBackendSetAttribute(op_matmul, CUDNN_ATTR_OPERATION_MATMUL_ADESC,
                                 CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &query));
    CUDNN_CHECK(
        cudnnBackendSetAttribute(op_matmul, CUDNN_ATTR_OPERATION_MATMUL_BDESC,
                                 CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &key));
    CUDNN_CHECK(
        cudnnBackendSetAttribute(op_matmul, CUDNN_ATTR_OPERATION_MATMUL_CDESC,
                                 CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &scores));
    CUDNN_CHECK(cudnnBackendFinalize(op_matmul));

    // Pointwise operation: scaled = scores * scaler (two-input MUL; the
    // {1,1,1,1} scaler broadcasts over scores).
    cudnnBackendDescriptor_t pw_mul;
    CUDNN_CHECK(cudnnBackendCreateDescriptor(CUDNN_BACKEND_POINTWISE_DESCRIPTOR,
                                             &pw_mul));
    cudnnPointwiseMode_t mul_mode = CUDNN_POINTWISE_MUL;
    CUDNN_CHECK(cudnnBackendSetAttribute(pw_mul, CUDNN_ATTR_POINTWISE_MODE,
                                         CUDNN_TYPE_POINTWISE_MODE, 1,
                                         &mul_mode));
    CUDNN_CHECK(cudnnBackendSetAttribute(pw_mul, CUDNN_ATTR_POINTWISE_MATH_PREC,
                                         CUDNN_TYPE_DATA_TYPE, 1, &data_type));
    CUDNN_CHECK(cudnnBackendFinalize(pw_mul));
    cudnnBackendDescriptor_t op_scale;
    CUDNN_CHECK(cudnnBackendCreateDescriptor(
        CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR, &op_scale));
    CUDNN_CHECK(cudnnBackendSetAttribute(
        op_scale, CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR,
        CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &pw_mul));
    CUDNN_CHECK(
        cudnnBackendSetAttribute(op_scale, CUDNN_ATTR_OPERATION_POINTWISE_XDESC,
                                 CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &scores));
    CUDNN_CHECK(
        cudnnBackendSetAttribute(op_scale, CUDNN_ATTR_OPERATION_POINTWISE_BDESC,
                                 CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &scaler));
    CUDNN_CHECK(
        cudnnBackendSetAttribute(op_scale, CUDNN_ATTR_OPERATION_POINTWISE_YDESC,
                                 CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &scaled));
    CUDNN_CHECK(cudnnBackendFinalize(op_scale));

    // Operation graph: the two ops are linked implicitly through the shared
    // virtual tensor 'o' (matmul output == pointwise X input).
    //
    // NOTE(review): with rank-4 operands (two leading batch dims) the
    // heuristics appear to route this matmul+pointwise graph through the MHA
    // pattern matcher ("mha patterns only support 2, 4 or 5 gemms");
    // collapsing batch and head into a single leading dim (rank-3
    // {B*H, M, K} tensors) reportedly avoids that path — TODO confirm
    // against the cuDNN version in use.
    cudnnBackendDescriptor_t op_graph;
    cudnnBackendDescriptor_t ops[] = {op_matmul, op_scale};
    CUDNN_CHECK(cudnnBackendCreateDescriptor(
        CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR, &op_graph));
    CUDNN_CHECK(cudnnBackendSetAttribute(
        op_graph, CUDNN_ATTR_OPERATIONGRAPH_OPS, CUDNN_TYPE_BACKEND_DESCRIPTOR,
        sizeof(ops) / sizeof(ops[0]), ops));
    CUDNN_CHECK(cudnnBackendSetAttribute(op_graph,
                                         CUDNN_ATTR_OPERATIONGRAPH_HANDLE,
                                         CUDNN_TYPE_HANDLE, 1, &handle));
    CUDNN_CHECK(cudnnBackendFinalize(op_graph));

    // Heuristics (mode A): two-phase query of CUDNN_ATTR_ENGINEHEUR_RESULTS —
    // first with requestedElementCount = 0 to get the count, then again to
    // fill pre-created engine-config descriptors.
    cudnnBackendDescriptor_t heur_desc;
    CUDNN_CHECK(cudnnBackendCreateDescriptor(
        CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR, &heur_desc));
    CUDNN_CHECK(cudnnBackendSetAttribute(
        heur_desc, CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH,
        CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &op_graph));
    cudnnBackendHeurMode_t heur_mode = CUDNN_HEUR_MODE_A;
    CUDNN_CHECK(cudnnBackendSetAttribute(heur_desc, CUDNN_ATTR_ENGINEHEUR_MODE,
                                         CUDNN_TYPE_HEUR_MODE, 1, &heur_mode));
    CUDNN_CHECK(cudnnBackendFinalize(heur_desc));
    int64_t count = 0;
    CUDNN_CHECK(cudnnBackendGetAttribute(
        heur_desc, CUDNN_ATTR_ENGINEHEUR_RESULTS, CUDNN_TYPE_BACKEND_DESCRIPTOR,
        0, &count, NULL));
    std::cout << "engines: " << count << "\n";
    std::vector<cudnnBackendDescriptor_t> eng_cfgs(count);
    for (cudnnBackendDescriptor_t &cfg : eng_cfgs) {
        CUDNN_CHECK(cudnnBackendCreateDescriptor(
            CUDNN_BACKEND_ENGINECFG_DESCRIPTOR, &cfg));
    }
    CUDNN_CHECK(cudnnBackendGetAttribute(
        heur_desc, CUDNN_ATTR_ENGINEHEUR_RESULTS, CUDNN_TYPE_BACKEND_DESCRIPTOR,
        count, nullptr, eng_cfgs.data()));

    // Try to build an execution plan from each engine config. Finalization
    // failures are reported (not fatal) so every config gets attempted.
    for (cudnnBackendDescriptor_t &cfg : eng_cfgs) {
        cudnnBackendDescriptor_t exec_plan;
        CUDNN_CHECK(cudnnBackendCreateDescriptor(
            CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR, &exec_plan));
        CUDNN_CHECK(cudnnBackendSetAttribute(
            exec_plan, CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG,
            CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &cfg));
        CUDNN_CHECK(cudnnBackendSetAttribute(exec_plan,
                                             CUDNN_ATTR_EXECUTION_PLAN_HANDLE,
                                             CUDNN_TYPE_HANDLE, 1, &handle));

        cudnnStatus_t status = cudnnBackendFinalize(exec_plan);
        std::cout << cudnnGetErrorString(status) << "\n";
        if (status == CUDNN_STATUS_SUCCESS) {
            std::cout << "success\n";
        } else {
            char error_message[256];
            cudnnGetLastErrorString(error_message, sizeof(error_message));
            std::cout << "reason: " << error_message << std::endl;
        }
    }

    // To be filled in

    // NOTE(review): descriptors and the handle are deliberately not destroyed
    // in this minimal repro; a real application should release them with
    // cudnnBackendDestroyDescriptor()/cudnnDestroy().
    return 0;
}

In this blog, at the end, there is an example of a working graph with a single MatMul operation. It can easily be extended with pointwise operations, as explained in the link you posted. Maybe this helps.
The blog is in German, but the code should be clear.

I am able to finalize the execution plan for a rank 3 tensor matrix multiplication into a pointwise operation when utilizing more recent hardware and after upgrading to version 9.6.0.

However when attempting to do the same with rank 4 tensors I get either one of these two errors for the 10 engines returned:
CUDNN_STATUS_INTERNAL_ERROR_COMPILATION_FAILED — reason: Encountered runtime kernel compilation failure at: compilationResult != NVRTC_SUCCESS
CUDNN_STATUS_BAD_PARAM_NOT_FINALIZED — reason: assert_finalized == true
with no other errors or warnings logged.

I also get the same results when just doing a simple rank 4 matrix multiplication. I am not sure why this happens when the documentation states that two batch dimensions are supported — maybe they aren't supported on my hardware. I think for now I am just going to attempt to use the deprecated APIs; the graph API causes too much difficulty.