MultiHeadAttnBackwardData Wrong Result with postDropout enabled

The following two methods should be equivalent:

  1. use the embedded dropout of MultiHeadAttn (via postDropoutDesc)
  2. disable the embedded dropout and apply a standalone Dropout to the output instead

Expected: both produce identical outputs and gradients.
Actual: the outputs are identical, but the gradients differ considerably.

Furthermore, a Python reference implementation agrees with method 2. Combining this evidence, it looks like cudnnMultiHeadAttnBackwardData is not working correctly when postDropout is enabled.

Reproduction below (compile with clang++-12 -std=c++20 -lcudnn -lcudart):

#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <driver_types.h>
#include <unistd.h>
#include <pthread.h>
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cudnn.h>
#include <cudnn_ops_infer.h>
#include <cudnn_ops_train.h>
#include <cudnn_adv_infer.h>
#include <cudnn_adv_train.h>

#include <memory>
#include <tuple>
#include <cstring>
#include <random>
#include <type_traits>
#include <vector>
#include <span>

inline void checkCudaError(cudaError_t code, const char *expr, const char *file, int line) {
    if (code) {
        fprintf(stderr, "ERROR: CUDA error at %s:%d, code=%d (%s) in '%s'\n\n",
                file, line, (int)code, cudaGetErrorString(code), expr);
        exit(1);
    }
}

inline void checkCudnnError(cudnnStatus_t code, const char *expr, const char *file, int line) {
    if (code) {
        fprintf(stderr, "CUDNN error at %s:%d, code=%d (%s) in '%s'\n\n",
                file, line, (int)code, cudnnGetErrorString(code), expr);
        exit(1);
    }
}

inline void do_assert(bool succ, const char *expr, const char *file, int line) {
    if (!succ) {
        fprintf(stderr, "assertion failed at %s:%d in '%s'\n\n",
                file, line, expr);
        exit(1);
    }
}

#define CHECK_CUDA_ERR(...)                                             \
    do {                                                                \
        checkCudaError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__);  \
    } while (0)

#define CHECK_CUDNN_ERR(...)                                            \
    do {                                                                \
        checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
    } while (0)

#define ASSERT(...)                                                     \
    do {                                                                \
        do_assert(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
    } while (0)


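// std::span tagged with its allocator so host and device pointers cannot be
// mixed up at compile time.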
template<typename T, typename Tag>
struct Span : std::span<T> {
    using std::span<T>::span;

    template<typename U>
    Span<U, Tag> as() {
        // Reinterpret the underlying bytes as a span of U.
        return {reinterpret_cast<U*>(std::span<T>::data()), std::span<T>::size_bytes() / sizeof(U)};
    }
};

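// Pinned host memory, allocated with cudaMallocHost.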
struct HostTraits {
    static void* alloc(size_t size) {
        void* ptr;
        CHECK_CUDA_ERR(cudaMallocHost(&ptr, size));
        return ptr;
    }

    static void free(void* ptr) {
        CHECK_CUDA_ERR(cudaFreeHost(ptr));
    }
};

struct DevTraits {
    static void* alloc(size_t size) {
        void* ptr;
        CHECK_CUDA_ERR(cudaMalloc(&ptr, size));
        return ptr;
    }

    static void free(void* ptr) {
        CHECK_CUDA_ERR(cudaFree(ptr));
    }
};

template <typename Traits>
struct Deleter {
    void operator()(void* ptr) {
        Traits::free(ptr);
    }
};

template<typename T, typename Traits>
struct Array : std::unique_ptr<T, Deleter<Traits>>, Span<T, Traits> {
    Array(size_t sz) :
            std::unique_ptr<T, Deleter<Traits>>(static_cast<T*>(Traits::alloc(sz * sizeof(T)))),
            Span<T, Traits>(std::unique_ptr<T, Deleter<Traits>>::get(), sz) {}
    Array(std::nullptr_t) {}
    Array() = default;
};

template <typename T> using DevArray = Array<T, DevTraits>;
template <typename T> using HostArray = Array<T, HostTraits>;
template <typename T> using DevSpan = Span<T, DevTraits>;
template <typename T> using HostSpan = Span<T, HostTraits>;


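// Typed copy helpers: the span tags make the transfer direction explicit.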
template <typename T>
void copy(HostSpan<T> dst, DevSpan<T> src) {
    ASSERT(dst.size() == src.size());
    CHECK_CUDA_ERR(cudaMemcpy(dst.data(), src.data(), dst.size_bytes(), cudaMemcpyDeviceToHost));
}

template <typename T>
void copy(DevSpan<T> dst, HostSpan<T> src) {
    ASSERT(dst.size() == src.size());
    CHECK_CUDA_ERR(cudaMemcpy(dst.data(), src.data(), dst.size_bytes(), cudaMemcpyHostToDevice));
}


template<typename G, typename T>
void fill_rand(DevSpan<T> span, G&& gen) {
    HostArray<T> hdata(span.size());
    std::uniform_real_distribution<T> dis;
    std::generate(hdata.begin(), hdata.end(), [&]{return dis(gen);});
    copy(span, hdata);
}


template<typename T>
void fill_const(DevSpan<T> span, T value) {
    HostArray<T> hdata(span.size());
    std::fill(hdata.begin(), hdata.end(), value);
    copy(span, hdata);
}


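// Total element count described by a cudnnSeqDataDescriptor_t
// (product of its four dimensions).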
size_t get_seq_size(cudnnSeqDataDescriptor_t desc) {
    int nbDims, dimA[4];
    cudnnDataType_t dtype;
    cudnnSeqDataAxis_t axes[4];
    size_t seqLengthArraySize;
    CHECK_CUDNN_ERR(cudnnGetSeqDataDescriptor(desc, &dtype, &nbDims, 4, dimA, axes, &seqLengthArraySize, 0, NULL, NULL));
    size_t sz = 1;
    for (int i = 0; i < 4; ++i) {
        sz *= dimA[i];
    }
    return sz;
}


void test_attn() {
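    // Problem sizes: N = batch, B = beam, H = heads, T = query sequence length,
    // Tk = key/value sequence length, Cq/Ck/Cv = input vector sizes,
    // Cq_/Cv_ = projection sizes, Co_ = output projection size.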
    int N = 32,
        B = 1,
        H = 8,
        T = 1024,
        Tk = 1024,
        Cq = 32,
        Ck = 32,
        Cv = 32,
        Cq_ = 32,
        Cv_ = 32,
        Co_ = 32;

    cudnnHandle_t handle;
    CHECK_CUDNN_ERR(cudnnCreate(&handle));

    size_t state_size = 0;
    CHECK_CUDNN_ERR(cudnnDropoutGetStatesSize(handle, &state_size));

    void* state;
    CHECK_CUDA_ERR(cudaMalloc(&state, state_size));

    cudnnDropoutDescriptor_t post_drop_desc;
    CHECK_CUDNN_ERR(cudnnCreateDropoutDescriptor(&post_drop_desc));

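    // The attention descriptor is identical for both methods except for
    // whether post_drop_desc is supplied as the postDropoutDesc argument.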
    auto make_attn_desc = [&](bool dropout) {
        cudnnAttnDescriptor_t attn_desc;
        CHECK_CUDNN_ERR(cudnnCreateAttnDescriptor(&attn_desc));
        CHECK_CUDNN_ERR(cudnnSetAttnDescriptor(
            attn_desc,
            CUDNN_ATTN_ENABLE_PROJ_BIASES,
            H,
            0.1,
            CUDNN_DATA_FLOAT,
            CUDNN_DATA_FLOAT,
            CUDNN_DEFAULT_MATH,
            NULL,
            dropout ? post_drop_desc : NULL,
            Cq, // qSize
            Ck, // kSize
            Cv, // vSize,
            Cq_, // qProjSize
            Cq_, // kProjSize
            Cv_, // vProjSize
            Co_, // oProjSize
            T, // qoMaxSeqLength
            Tk, // kvMaxSeqLength,
            N, // maxBatchSize
            B // maxBeamSize
        ));
        return attn_desc;
    };

    std::minstd_rand gen(42);

    auto make_seq_desc = [&](int N, int B, int T, int C, const int *seqLengthArray, size_t seqLengthArraySize) {
        cudnnSeqDataDescriptor_t desc;
        CHECK_CUDNN_ERR(cudnnCreateSeqDataDescriptor(&desc));
        int dimA[CUDNN_SEQDATA_DIM_COUNT];
        dimA[CUDNN_SEQDATA_BATCH_DIM] = N;
        dimA[CUDNN_SEQDATA_BEAM_DIM] = B;
        dimA[CUDNN_SEQDATA_TIME_DIM] = T;
        dimA[CUDNN_SEQDATA_VECT_DIM] = C;
        cudnnSeqDataAxis_t axes[CUDNN_SEQDATA_DIM_COUNT] = {
            CUDNN_SEQDATA_BATCH_DIM,
            CUDNN_SEQDATA_BEAM_DIM,
            CUDNN_SEQDATA_TIME_DIM,
            CUDNN_SEQDATA_VECT_DIM
        };
        CHECK_CUDNN_ERR(cudnnSetSeqDataDescriptor(desc, CUDNN_DATA_FLOAT, 4, dimA, axes, seqLengthArraySize, seqLengthArray, NULL));
        return desc;
    };

    std::vector<int> qSeqLengthArray(N * B, T), kSeqLengthArray(N, Tk);

    size_t weightSizeInBytes, _unused_size_t;
    CHECK_CUDNN_ERR(cudnnGetMultiHeadAttnBuffers(
        handle, make_attn_desc(false), &weightSizeInBytes, &_unused_size_t, &_unused_size_t
    ));

    auto qDesc = make_seq_desc(N, B, T, Cq, &qSeqLengthArray[0], qSeqLengthArray.size());
    auto kDesc = make_seq_desc(N, 1, Tk, Ck, &kSeqLengthArray[0], kSeqLengthArray.size());
    auto vDesc = make_seq_desc(N, 1, Tk, Cv, &kSeqLengthArray[0], kSeqLengthArray.size());
    auto oDesc = make_seq_desc(N, B, T, Co_, &qSeqLengthArray[0], qSeqLengthArray.size());

    DevArray<float>
            queries(get_seq_size(qDesc)),
            dqueries(get_seq_size(qDesc)),
            keys(get_seq_size(kDesc)),
            dkeys(get_seq_size(kDesc)),
            values(get_seq_size(vDesc)),
            dvalues(get_seq_size(vDesc)),
            out(get_seq_size(oDesc)),
            dout(get_seq_size(oDesc)),
            weights(weightSizeInBytes / sizeof(float)),
            dweights(weightSizeInBytes / sizeof(float))
            ;

    fill_rand(weights, gen);
    fill_rand(queries, gen);
    fill_rand(keys, gen);
    fill_rand(values, gen);
    fill_rand(dout, gen);

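    // The constant fills below overwrite the random data above (presumably a
    // debugging aid); remove them to exercise random inputs instead.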
    // fill_const(weights, 1.0f);
    fill_const(queries, 1.0f);
    fill_const(keys, 1.0f);
    fill_const(values, 1.0f);
    fill_const(dout, 1.0f);

    void *residuals = NULL;

    std::vector<int> loWinIdx(T, 0), hiWinIdx(T, Tk);
    DevArray<int> devSeqLengthsQO(N * B), devSeqLengthsKV(N * 1);
    fill_const(devSeqLengthsQO, T);
    fill_const(devSeqLengthsKV, Tk);

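    // Runs one forward + backward pass and returns (output, dQueries).
    // dropout=true uses the embedded postDropout inside the attention calls;
    // dropout=false applies the same dropout manually around them.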
    auto attn_fwd_bwd = [&](bool dropout) {
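        // Re-seed the dropout RNG state so both runs draw the identical mask.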
        CHECK_CUDNN_ERR(cudnnSetDropoutDescriptor(post_drop_desc, handle, 0.5, state, state_size, 42));
        auto attn_desc = make_attn_desc(dropout);

        size_t workSpaceSizeInBytes, reserveSpaceSizeInBytes;
        CHECK_CUDNN_ERR(cudnnGetMultiHeadAttnBuffers(
            handle, attn_desc, &_unused_size_t, &workSpaceSizeInBytes, &reserveSpaceSizeInBytes
        ));

        DevArray<char> workSpace(workSpaceSizeInBytes), reserveSpace(reserveSpaceSizeInBytes);

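        // Training-mode forward pass: currIdx = -1 processes all time steps,
        // and the non-NULL reserve space is filled for the backward pass.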
        CHECK_CUDNN_ERR(cudnnMultiHeadAttnForward(
            handle,
            attn_desc,
            -1,
            &loWinIdx[0],
            &hiWinIdx[0],
            devSeqLengthsQO.get(),
            devSeqLengthsKV.get(),
            qDesc,
            queries.get(),
            residuals,
            kDesc,
            keys.get(),
            vDesc,
            values.get(),
            oDesc,
            out.get(),
            weightSizeInBytes,
            weights.get(),
            workSpaceSizeInBytes,
            workSpace.get(),
            reserveSpaceSizeInBytes,
            reserveSpace.get()
        ));

        HostArray<float> hout(out.size());
        DevArray<float> dout1;

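        // With the embedded dropout disabled, emulate it by hand: dropout
        // forward on the attention output, dropout backward on the incoming
        // gradient before it is fed to cudnnMultiHeadAttnBackwardData.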
        if (!dropout) {
            cudnnTensorDescriptor_t tdesc;
            CHECK_CUDNN_ERR(cudnnCreateTensorDescriptor(&tdesc));

            int dimA[4] = {1, 1, 1, 1}, strideA[4] = {1, 1, 1, 1};
            dimA[0] = (int)out.size();
            CHECK_CUDNN_ERR(cudnnSetTensorNdDescriptor(tdesc, CUDNN_DATA_FLOAT, 4, dimA, strideA));

            size_t reserve_size = 0;
            CHECK_CUDNN_ERR(cudnnDropoutGetReserveSpaceSize(tdesc, &reserve_size));

            // Reserve space for the standalone dropout op, kept separate from
            // the attention reserve space still needed by the backward pass.
            DevArray<char> dropReserve(reserve_size);

            // Apply dropout in place on the attention output.
            CHECK_CUDNN_ERR(cudnnDropoutForward(
                handle,
                post_drop_desc,
                tdesc,
                out.get(),
                tdesc,
                out.get(),
                dropReserve.get(),
                reserve_size
            ));

            copy(hout, out);

            // Propagate the incoming gradient back through the dropout so the
            // attention backward pass sees the masked gradient.
            dout1 = DevArray<float>(dout.size());

            CHECK_CUDNN_ERR(cudnnDropoutBackward(
                handle,
                post_drop_desc,
                tdesc,
                dout.get(),
                tdesc,
                dout1.get(),
                dropReserve.get(),
                reserve_size
            ));
        } else {
            copy(hout, out);
        }

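        // Compute data gradients. With the embedded postDropout, cuDNN should
        // apply the dropout mask to the incoming gradient internally; this is
        // where the two methods end up disagreeing.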
        CHECK_CUDNN_ERR(cudnnMultiHeadAttnBackwardData(
            handle,
            attn_desc,
            &loWinIdx[0],
            &hiWinIdx[0],
            devSeqLengthsQO.get(),
            devSeqLengthsKV.get(),
            oDesc,
            dout1 ? dout1.get() : dout.get(),
            qDesc,
            dqueries.get(),
            queries.get(),
            kDesc,
            dkeys.get(),
            keys.get(),
            vDesc,
            dvalues.get(),
            values.get(),
            weightSizeInBytes,
            weights.get(),
            workSpaceSizeInBytes,
            workSpace.get(),
            reserveSpaceSizeInBytes,
            reserveSpace.get()
        ));

        HostArray<float> hdq(dqueries.size());
        copy(hdq, dqueries);

        return std::make_tuple(std::move(hout), std::move(hdq));
    };

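    // Method 1: embedded postDropout. Method 2: standalone dropout.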
    auto [hout1, hdq1] = attn_fwd_bwd(true);
    auto [hout2, hdq2] = attn_fwd_bwd(false);

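    // With identical seeds the two methods should agree bit-for-bit.
    // The output comparison passes; the gradient comparison fails.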
    for (size_t i = 0; i < hout1.size(); ++i) {
        ASSERT(hout1[i] == hout2[i]);
    }
    for (size_t i = 0; i < hdq1.size(); ++i) {
        ASSERT(hdq1[i] == hdq2[i]);
    }
}

int main() {
    test_attn();
    return 0;
}