Why is my ldmatrix PTX instruction wrong?

I came across an interesting problem related to ‘ldmatrix’ while searching. I copied the code and tried to debug it. However, when I made slight modifications, the code didn’t produce the correct results. Here is my code:

#include <iostream>

__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer)
{
    uint32_t address;
    asm("{\n\t"
        "  .reg .u64 u64addr;\n\t"
        "  cvta.to.shared.u64 u64addr, %1;\n\t"
        "  cvt.u32.u64 %0, u64addr;\n\t"
        "}"
        : "=r"(address)
        : "l"(pointer));
    return address;
}

__device__ __forceinline__ void ldmatrix_sync_aligned_m8n8_x1_b16(
    uint32_t &d0, uint32_t &d1, uint32_t &d2, uint32_t &d3,
    const uint32_t &address)
{
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];"
        : "=r"(d0)
        : "r"(address));
}

__global__ void ldmatrix_bank_conflict(uint16_t *value)
{
    constexpr int N = 4096;
    __shared__ uint16_t smem[N];
    auto tid = threadIdx.x;
    // bank: tid * 64 * 2 / 4 % 32 = 0
    // T0 read: smem[0],   ..., smem[7]   -> bank 0 to bank 3
    // T1 read: smem[64],  ..., smem[71]  -> bank 0 to bank 3
    // T2 read: smem[128], ..., smem[135] -> bank 0 to bank 3
    // T3 read: smem[192], ..., smem[199] -> bank 0 to bank 3
    // T4 read: smem[256], ..., smem[263] -> bank 0 to bank 3
    // T5 read: smem[320], ..., smem[327] -> bank 0 to bank 3
    // T6 read: smem[384], ..., smem[391] -> bank 0 to bank 3
    // T7 read: smem[448], ..., smem[455] -> bank 0 to bank 3
    const uint32_t address =
        cvta_to_shared_u32(smem) + sizeof(uint16_t) * (tid * 8);

    for (uint32_t i = tid; i < N; i += blockDim.x)
    {
        smem[i] = i;
    }
    __syncthreads();

    uint32_t frag[4];
    ldmatrix_sync_aligned_m8n8_x1_b16(
        frag[0], frag[1], frag[2], frag[3], address);

    __syncthreads();

    uint16_t number1 = static_cast<uint16_t>(frag[0] & 0xFFFF);
    uint16_t number2 = static_cast<uint16_t>((frag[0] >> 16) & 0xFFFF);
    uint16_t number3 = static_cast<uint16_t>(frag[1] & 0xFFFF);
    uint16_t number4 = static_cast<uint16_t>((frag[1] >> 16) & 0xFFFF);
    printf("%d  %d   %d  %d   %d   \n", smem[64 * tid], number1, number2, number3, number4);
}

int main()
{
    uint16_t *d_value;
    cudaMalloc(&d_value, sizeof(uint16_t));
    ldmatrix_bank_conflict<<<1, 32>>>(d_value);
    cudaDeviceSynchronize();
    cudaFree(d_value);
    return 0;
}

I want to load an 8x8 matrix, but when N=64 the code throws an error, whereas when N=4096 it executes correctly. Based on the ldmatrix documentation I found, threads 0-7 supply addresses 0-7. So once the thread index exceeds 7, the computed address goes beyond the range of the shared memory I allocated. What should I do in this situation?

1. When N=64, your shared memory definition looks like this:

    __shared__ uint16_t smem[64];
    

    So this address computation is going to be a problem, because with tid running from 0 to 31 the computed offsets extend well past the 64 elements you allocated:

        const uint32_t address =
            cvta_to_shared_u32(smem) + sizeof(uint16_t) * (tid * 8);

2. %d is an incorrect printf format specifier for a 16-bit quantity. We will fix that by casting it to int rather than trying to change the format specifier.

3. This will also be a problem:

    const uint32_t address =
        cvta_to_shared_u32(smem) + sizeof(uint16_t) * (tid * 8);

Note the documentation:

For .target sm_75 or below, all threads must contain valid addresses. Otherwise, the behavior is undefined. For .num = .x1 and .num = .x2, addresses contained in lower threads can be copied to higher threads to achieve the expected behavior.

So we prefer (for the .x1 case) to see a construct like this:

const uint32_t address =
    cvta_to_shared_u32(smem) + sizeof(uint16_t) * ((tid%8) * 8);

This seems to produce sensible output for me:

# cat t164.cu
#include <iostream>

__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer)
{
    uint32_t address;
    asm("{\n\t"
        "  .reg .u64 u64addr;\n\t"
        "  cvta.to.shared.u64 u64addr, %1;\n\t"
        "  cvt.u32.u64 %0, u64addr;\n\t"
        "}"
        : "=r"(address)
        : "l"(pointer));
    return address;
}

__device__ __forceinline__ void ldmatrix_sync_aligned_m8n8_x1_b16(
    uint32_t &d0, uint32_t &d1, uint32_t &d2, uint32_t &d3,
    const uint32_t &address)
{
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];"
        : "=r"(d0)
        : "r"(address));
}

__global__ void ldmatrix_bank_conflict(uint16_t *value)
{
    constexpr int N = 4096;
    __shared__ uint16_t smem[N];
    auto tid = threadIdx.x;
    // bank: tid * 64 * 2 / 4 % 32 = 0
    // T0 read: smem[0],   ..., smem[7]   -> bank 0 to bank 3
    // T1 read: smem[64],  ..., smem[71]  -> bank 0 to bank 3
    // T2 read: smem[128], ..., smem[135] -> bank 0 to bank 3
    // T3 read: smem[192], ..., smem[199] -> bank 0 to bank 3
    // T4 read: smem[256], ..., smem[263] -> bank 0 to bank 3
    // T5 read: smem[320], ..., smem[327] -> bank 0 to bank 3
    // T6 read: smem[384], ..., smem[391] -> bank 0 to bank 3
    // T7 read: smem[448], ..., smem[455] -> bank 0 to bank 3
    const uint32_t address =
        cvta_to_shared_u32(smem) + sizeof(uint16_t) * ((tid%8) * 8);

    for (uint32_t i = tid; i < N; i += blockDim.x)
    {
        smem[i] = i;
    }
    __syncthreads();

    uint32_t frag[4];
    ldmatrix_sync_aligned_m8n8_x1_b16(
        frag[0], frag[1], frag[2], frag[3], address);

    __syncthreads();

    uint16_t number1 = static_cast<uint16_t>(frag[0] & 0xFFFF);
    uint16_t number2 = static_cast<uint16_t>((frag[0] >> 16) & 0xFFFF);
    uint16_t number3 = static_cast<uint16_t>(frag[1] & 0xFFFF);
    uint16_t number4 = static_cast<uint16_t>((frag[1] >> 16) & 0xFFFF);
    printf("%d  %d   %d  %d   %d   \n", (int)(smem[2 * tid]), (int)number1, (int)number2, (int)number3, (int)number4);
}

int main()
{
    uint16_t *d_value;
    cudaMalloc(&d_value, sizeof(uint16_t));
    ldmatrix_bank_conflict<<<1, 32>>>(d_value);
    cudaDeviceSynchronize();
    cudaFree(d_value);
    return 0;
}
# nvcc -o t164 t164.cu -arch=sm_89 -lineinfo
# compute-sanitizer ./t164
========= COMPUTE-SANITIZER
0  0   1  0   0
2  2   3  0   0
4  4   5  0   0
6  6   7  0   0
8  8   9  0   0
10  10   11  0   0
12  12   13  0   0
14  14   15  0   0
16  16   17  0   0
18  18   19  0   0
20  20   21  0   0
22  22   23  0   0
24  24   25  0   0
26  26   27  0   0
28  28   29  0   0
30  30   31  0   0
32  32   33  0   0
34  34   35  0   0
36  36   37  0   0
38  38   39  0   0
40  40   41  0   0
42  42   43  0   0
44  44   45  0   0
46  46   47  0   0
48  48   49  0   0
50  50   51  0   0
52  52   53  0   0
54  54   55  0   0
56  56   57  0   0
58  58   59  0   0
60  60   61  0   0
62  62   63  0   0
========= ERROR SUMMARY: 0 errors
#
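
As an aside, if the goal is for every one of the 32 lanes to load its own 8-element row (i.e. four 8x8 tiles at once), the .x4 variant of ldmatrix takes an address from every thread in the warp. A minimal sketch of such a wrapper (the function name is mine, and it assumes smem holds at least 256 uint16_t elements so that the original tid * 8 offsets stay in range):

__device__ __forceinline__ void ldmatrix_sync_aligned_m8n8_x4_b16(
    uint32_t &d0, uint32_t &d1, uint32_t &d2, uint32_t &d3,
    const uint32_t &address)
{
    // .x4 loads four 8x8 b16 matrices; all 32 lanes supply a row address
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
        : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
        : "r"(address));
}

With that variant, the original address computation (without the %8) would be valid for all 32 threads.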

I hate to be a language lawyer, but to first order: in an expression, an integer type narrower than int is widened to int, so there is an implicit conversion. This falls under chapters 4 (“standard conversions”) and 5 (“expressions”) of the ISO C++11 standard. On all platforms currently supported by CUDA, int is a 32-bit type, while uint16_t is typedefed via cstdint to unsigned short int, which is a 16-bit type. int is therefore wider, and of higher conversion rank, than uint16_t, so an implicit conversion from uint16_t to int is applied, delivering the desired result. We may wish to make that conversion explicit, of course, but it is not strictly necessary. On a platform with a 16-bit int (e.g. 8-bit and 16-bit microprocessors) the situation would naturally be different.
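
A minimal self-contained illustration of that promotion rule (the value is arbitrary):

#include <cstdio>  // printf
#include <cstdint> // uint16_t

int main (void)
{
    uint16_t a = 40000;
    // As an argument to the variadic printf, 'a' undergoes the default
    // argument promotions and is passed as a (32-bit) int, so %d prints
    // 40000 as expected even though no explicit cast is written.
    printf ("The value of a is: %d\n", a);
    return 0;
}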

It is a best practice (one I frequently ignore for specific-sized integer types) to always use matching printf() format specifiers for all arguments. For specific-sized integer types these are pre-defined macros in the header file cinttypes, and for uint16_t that would be PRIu16, which on any platform supported by CUDA is exceedingly likely to be defined as "u". Example:

#include <cstdio> // printf
#include <cstdint> // uint16_t
#include <cinttypes> // PRIu16

int main (void)
{
    uint16_t a = 5;
    printf ("The value of a is: %" PRIu16 "\n", a);
    return 0;
}

Noted. Skip my item 2.

Thank you all very much! I greatly appreciate it!

test.cu
#include <iostream>
#include <cuda_fp16.h>
#include <mma.h>
#include <cuda_fp16.h>
#include <mma.h>
#include <cuda.h>

#define BLOCK_SIZE 16

using namespace nvcuda;

__device__ void Load_SmemA(int lane_id,
                           int m, int n,
                           int d,
                           int step,
                           half *value,
                           half *smem_value)
{
    int tt = (m * n + blockDim.x * blockDim.y - 1) / (blockDim.x * blockDim.y);

    for (int k = 0; k < tt; k++)
    {
        int temp = lane_id + k * blockDim.x * blockDim.y;
        int x = temp / n, y = temp % n;
        smem_value[x * n + y] = value[x * d + y + step * n];
        __syncthreads();
    }
    __syncthreads();
}

__device__ void Load_SmemB(int lane_id,
                           int m, int n,
                           int d,
                           int step,
                           half *value,
                           half *smem_value)
{
    int tt = (m * n + blockDim.x * blockDim.y - 1) / (blockDim.x * blockDim.y);
    for (int k = 0; k < tt; k++)
    {
        int temp = lane_id + k * blockDim.x * blockDim.y;
        int x = temp / n, y = temp % n;
        smem_value[x * n + y] = value[(x + step * m) * d + y];
    }
}

__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer)
{
    uint32_t address;
    asm("{\n\t"
        "  .reg .u64 u64addr;\n\t"
        "  cvta.to.shared.u64 u64addr, %1;\n\t"
        "  cvt.u32.u64 %0, u64addr;\n\t"
        "}"
        : "=r"(address)
        : "l"(pointer));
    return address;
}
__device__ void Load_FragA(int lane_id,
                           int group_warp,
                           half *value,
                           uint32_t *frag)
{
    uint32_t shmem_A_lane_addr = __cvta_generic_to_shared(value) + ((lane_id % 16) * 8) * sizeof(half);

    asm(
        "ldmatrix.sync.aligned.x2.m8n8.shared.b16 "
        "{%0,%1}, [%2]; "
        : "=r"(frag[0]), "=r"(frag[1])
        : "r"(shmem_A_lane_addr));

    __syncthreads();

    // Extracting the first half value
    half firstHalf = *reinterpret_cast<half *>(&frag[0]);
    uint32_t shiftedValue = frag[0] >> 16;
    half secondHalf = *reinterpret_cast<half *>(&shiftedValue);
    half thirdHalf = *reinterpret_cast<half *>(&frag[1]);
    shiftedValue = frag[1] >> 16;
    half forthHalf = *reinterpret_cast<half *>(&shiftedValue);
    printf("%f     %f     %f    %f \n", __half2float(firstHalf), __half2float(secondHalf), __half2float(thirdHalf), __half2float(forthHalf));
}

__device__ void Store_SmemC(int lane_id,
                            int group_warp,
                            half *value,
                            uint32_t *frag)
{
    uint32_t shmem_A_lane_addr = __cvta_generic_to_shared(value) + ((lane_id % 8) * 16) * sizeof(half);

    asm(
        "stmatrix.sync.aligned.x1.m8n8.shared.b16 "
        "{%0}, [%1]; "
        : "=r"(shmem_A_lane_addr)
        : "r"(frag[0]));
}

// m16n8k8
__global__ void matrixMultiply(half *A, half *B, half *C,
                               int m, int n, int k)
{

    int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) % 32;
    int group_warp = (threadIdx.y * blockDim.x + threadIdx.x) / 32;
    __shared__ half smem_A[16 * 8];
    __shared__ half smem_B[16 * 8];
    __shared__ half smem_C[16 * 16];

    int threadId = threadIdx.y * blockDim.x + threadIdx.x;

    __shared__ uint16_t test_int[16 * 8];
    test_int[threadId] = threadId;

    uint32_t fragA[2];
    uint32_t fragB[2];
    uint32_t acc[4] = {0};

    __syncthreads();
    for (int i = 0; i < 4; i++)
    {
        Load_SmemA(threadId, 16, 8, k, 2, A, smem_A);
        __syncthreads();
        Load_FragA(lane_id, group_warp, smem_A, fragA);
        __syncthreads();
       //break;
    }
}

int main()
{
    // int N = 32;
    int m = 16, n = 16, k = 128;

    half *h_A = new half[m * k];
    half *h_B = new half[k * n];
    half *h_C = new half[m * n];

    for (int i = 0; i < m; i++)
    {
        for (int j = 0; j < k; j++)
        {
            h_A[i * k + j] = (half)(i);
        }
    }
    for (int i = 0; i < k; i++)
    {
        for (int j = 0; j < n; j++)
        {
            h_B[i * n + j] = (half)(j);
        }
    }
    for (int i = 0; i < m; i++)
    {
        for (int j = 0; j < n; j++)
        {
            h_C[i * n + j] = (half)(0.0f);
        }
    }

    half *d_A, *d_B;

    half *d_C;
    cudaMalloc((void **)&d_A, m * k * sizeof(half));
    cudaMalloc((void **)&d_B, k * n * sizeof(half));
    cudaMalloc((void **)&d_C, m * n * sizeof(half));
    cudaMemcpy(d_A, h_A, m * k * sizeof(half), cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, h_B, k * n * sizeof(half), cudaMemcpyHostToDevice);

    dim3 blockDim(1, 32);
    dim3 gridDim(1, 1);

    matrixMultiply<<<gridDim, blockDim>>>(d_A, d_B, d_C, m, n, k);
    cudaDeviceSynchronize();
    cudaMemcpy(h_C, d_C, m * n * sizeof(half), cudaMemcpyDeviceToHost);

    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);
    delete[] h_A;
    delete[] h_B;
    delete[] h_C;

    return 0;
}

When I actually use ldmatrix, I encounter a very strange issue. Here is the specific code. I cannot print the correct result when I use ldmatrix inside a for loop, but when I add a break statement at the end of the for-loop body, the result is correct. I don’t understand this.
Here is the compilation command: nvcc -arch=sm_75 -o test_ ./test.cu

It appears to me to be a code-generation issue, i.e. a compiler problem. If I leave the break; statement commented out and change the loop extent, I see varying behavior:

If I make it:

for (int i = 0; i < 2; i++)

I see expected printout. If I make it:

for (int i = 0; i < 3; i++)

I see all zeros (cc8.9 on CUDA 12.2). I can’t explain why that would be. compute-sanitizer reports no issues in any case. I also note that if I compile with -G with a loop extent of 3, I get the expected output rather than all zeros.

My suggestion is to retest on CUDA 12.4 (i.e. latest CUDA version available at the moment) and if the issue persists, then file a bug.

If you do file a bug, my suggestion is to strip out anything unnecessary in the code. The compiler warns that there are various unused variables, and there are even unused functions.

The posted code has the break inside the for-loop, at the end. That is equivalent to writing

for (int i = 0; i < 1; i++)

For a quick workaround, try adding #pragma unroll 4 on the line immediately prior to the for-loop. With that in place I am getting the desired output with CUDA 12.3 for an sm_75 target.
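
Concretely, in terms of the loop in the posted test.cu, the placement I have in mind looks like this (a sketch of the workaround, not a reviewed fix):

#pragma unroll 4
    for (int i = 0; i < 4; i++)
    {
        Load_SmemA(threadId, 16, 8, k, 2, A, smem_A);
        __syncthreads();
        Load_FragA(lane_id, group_warp, smem_A, fragA);
        __syncthreads();
    }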

Using the code as posted, turning off ptxas optimizations in isolation does not change the observed behavior, but turning off all optimizations produces the desired output. This suggests that if there is a compiler issue, it is in the LLVM-derived NVVM portion of the compiler, i.e. the part that generates PTX.
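
For reference, command lines along these lines are what I mean (the exact flag choices here are my assumption, not a verbatim record of what was run):

# ptxas optimizations off only; NVVM (PTX generation) still optimizes
nvcc -arch=sm_75 -Xptxas -O0 -o test_ ./test.cu
# device-side optimizations off entirely (debug build)
nvcc -arch=sm_75 -G -o test_ ./test.cu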

I have not reviewed the code to ensure that no undefined C++ behavior is invoked and that there are sufficient synchronization / barrier primitives being used. That should happen before filing a bug against the compiler.

I think the problem hasn’t been completely resolved even with #pragma unroll. Here’s the specific code for the matrix multiplication. When k = 128 and k = 1280, the result is correct. However, when k = 12800, the output is all zeros. The debugging outputs in the error message are also all zeros. I’m very confused about this and I hope to get your help.

#include <iostream>
#include <cuda_fp16.h>
#include <mma.h>
#include <cuda_fp16.h>
#include <mma.h>
#include <cuda.h>

#define BLOCK_SIZE 16

using namespace nvcuda;

__device__ void Load_SmemA(int lane_id,
                           int m, int n,
                           int d,
                           int step,
                           half *value,
                           half *smem_value)
{
    int tt = (m * n + blockDim.x * blockDim.y - 1) / (blockDim.x * blockDim.y);

    for (int k = 0; k < tt; k++)
    {
        int temp = lane_id + k * blockDim.x * blockDim.y;
        int x = temp / n, y = temp % n;
        smem_value[x * n + y] = value[x * d + y + step * n];
        __syncthreads();
    }
    __syncthreads();
}

__device__ void Load_SmemB(int lane_id,
                           int m, int n,
                           int d,
                           int step,
                           half *value,
                           half *smem_value)
{
    int tt = (m * n + blockDim.x * blockDim.y - 1) / (blockDim.x * blockDim.y);
    for (int k = 0; k < tt; k++)
    {
        int temp = lane_id + k * blockDim.x * blockDim.y;
        int x = temp / n, y = temp % n;
        smem_value[x * n + y] = value[(x + step * m) * d + y];
    }
}

__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer)
{
    uint32_t address;
    asm("{\n\t"
        "  .reg .u64 u64addr;\n\t"
        "  cvta.to.shared.u64 u64addr, %1;\n\t"
        "  cvt.u32.u64 %0, u64addr;\n\t"
        "}"
        : "=r"(address)
        : "l"(pointer));
    return address;
}
__device__ void Load_FragA(int lane_id,
                           int group_warp,
                           half *value,
                           uint32_t *frag)
{
    // uint32_t shmem_A_lane_addr = __cvta_generic_to_shared(value) + ((lane_id % 16) * 8) * sizeof(uint16_t);
    uint32_t shmem_A_lane_addr = __cvta_generic_to_shared(value) + ((lane_id % 16) * 8) * sizeof(half);
    // printf("%.2f   \n", __half2float(value[(lane_id % 16) * 8]));
    asm(
        "ldmatrix.sync.aligned.x2.m8n8.shared.b16 "
        "{%0,%1}, [%2]; "
        : "=r"(frag[0]), "=r"(frag[1])
        : "r"(shmem_A_lane_addr));
}

__device__ void Load_FragB(int lane_id,
                           int group_warp,
                           half *value,
                           uint32_t *frag)
{

    if (group_warp == 0)
    {
        uint32_t shmem_A_lane_addr = __cvta_generic_to_shared(value) + ((lane_id % 8) * 16 + 8 * group_warp) * sizeof(half);

        asm(
            "ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 "
            "{%0}, [%1]; "
            : "=r"(frag[group_warp])
            : "r"(shmem_A_lane_addr));
    }
    else
    {
        uint32_t shmem_A_lane_addr = __cvta_generic_to_shared(value) + ((lane_id % 8) * 16 + 8 * group_warp) * sizeof(half);

        asm(
            "ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 "
            "{%0}, [%1]; "
            : "=r"(frag[group_warp])
            : "r"(shmem_A_lane_addr));
    }
}

__device__ void compute(int lane_id,
                        int group_warp,
                        uint32_t *fragA,
                        uint32_t *fragB,
                        uint32_t *fragC)
{
    if (group_warp == 0)
    {

        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
            " {%0, %1},"
            " {%2, %3},"
            " {%4},"
            " {%5, %6};\n"
            : "=r"(fragC[0]), "=r"(fragC[1])
            : "r"(fragA[0]), "r"(fragA[1]),
              "r"(fragB[0]),
              "r"(fragC[0]), "r"(fragC[1]));
    }
    else
    {

        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
            " {%0, %1},"
            " {%2, %3},"
            " {%4},"
            " {%5, %6};\n"
            : "=r"(fragC[2]), "=r"(fragC[3])
            : "r"(fragA[0]), "r"(fragA[1]),
              "r"(fragB[1]),
              "r"(fragC[2]), "r"(fragC[3]));
    }
}

__device__ void Store_SmemC(int lane_id,
                            int group_warp,
                            half *value,
                            uint32_t *frag)
{
    int x = lane_id / 4, y = lane_id % 4;

    *(uint32_t *)(&value[x * 16 + 2 * y + 8 * group_warp]) = frag[group_warp * 2];
    *(uint32_t *)(&value[(x + 8) * 16 + 2 * y + 8 * group_warp]) = frag[group_warp * 2 + 1];
}

// m16n8k8
//  CUDA kernel: matrix multiplication using Tensor Cores
template <int m, int n, int k>
__global__ void matrixMultiply(half *A, half *B, half *C)
{

    int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) % 32;
    int group_warp = (threadIdx.y * blockDim.x + threadIdx.x) / 32;
    __shared__ half smem_A[16 * 8];
    __shared__ half smem_B[16 * 8];
    __shared__ half smem_C[16 * 16];

    int threadId = threadIdx.y * blockDim.x + threadIdx.x;

    uint32_t fragA[2];
    uint32_t fragB[2];
    uint32_t acc[4] = {0};

#pragma unroll
    for (int i = 0; i < k / 8; i++)
    {
        Load_SmemA(threadId, 16, 8, k, i, A, smem_A);
        Load_SmemB(threadId, 8, 16, n, i, B, smem_B);
        __syncthreads();

        Load_FragA(lane_id, group_warp, smem_A, fragA);
        Load_FragB(lane_id, group_warp, smem_B, fragB);

        compute(lane_id, group_warp, fragA, fragB, acc);
        if (threadId == 15 && i == 0)
        {
            half firstHalf = *reinterpret_cast<half *>(&fragA[1]);
            uint32_t shiftedValue = fragA[1] >> 16;
            half secondHalf = *reinterpret_cast<half *>(&shiftedValue);
            printf("%f     %f      \n", __half2float(firstHalf), __half2float(secondHalf));
        }
        __syncthreads();
    }

    Store_SmemC(lane_id, group_warp, smem_C, acc);
    __syncthreads();
    if (threadId == 0)
    {
        for (int ii = 0; ii < 16; ii++)
        {
            for (int jj = 0; jj < 16; jj++)
            {
                printf("%f  ", __half2float(smem_C[ii * 16 + jj]));
            }
            printf("\n");
        }
    }
}

int main()
{
    // int k = 128  1280  12800
    int m = 16, n = 16, k = 12800;

    half *h_A = new half[m * k];
    half *h_B = new half[k * n];
    half *h_C = new half[m * n];

    for (int i = 0; i < m; i++)
    {
        for (int j = 0; j < k; j++)
        {
            h_A[i * k + j] = (half)(i * 0.1);
        }
    }
    for (int i = 0; i < k; i++)
    {
        for (int j = 0; j < n; j++)
        {
            h_B[i * n + j] = (half)(j * 0.1);
        }
    }
    for (int i = 0; i < m; i++)
    {
        for (int j = 0; j < n; j++)
        {
            h_C[i * n + j] = (half)(0.0f);
        }
    }

    half *d_A, *d_B;

    half *d_C;
    cudaMalloc((void **)&d_A, m * k * sizeof(half));
    cudaMalloc((void **)&d_B, k * n * sizeof(half));
    cudaMalloc((void **)&d_C, m * n * sizeof(half));
    cudaMemcpy(d_A, h_A, m * k * sizeof(half), cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, h_B, k * n * sizeof(half), cudaMemcpyHostToDevice);

    dim3 blockDim(1, 64);
    dim3 gridDim(1, 1);

    matrixMultiply<16, 16, 12800><<<gridDim, blockDim>>>(d_A, d_B, d_C);
    cudaDeviceSynchronize();
    cudaMemcpy(h_C, d_C, m * n * sizeof(half), cudaMemcpyDeviceToHost);

    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);
    delete[] h_A;
    delete[] h_B;
    delete[] h_C;

    return 0;
}

Sorry, I am not familiar with the GPU functionality exercised here and have no further insights. You may wish to file a bug report with NVIDIA.

Well, thank you very much for your reply.