Ampere 16x8x256 BMMA

The CUDA Toolkit docs show a 16x8x256 BMMA for Ampere.

However, the PTX docs for bmma_sync, the Tensor Core element tables, and mma.h all state only the 8x8x128 case.

I am wondering whether the “bigger” versions are undocumented, or whether the table is wrong.

If it is relevant: I need to do a kind of GEMM-style access-pattern reduction. Given two arrays of 256-bit integers, I need popc(a xor b) for all pairs across the two arrays, a kind of outer product. There is a reduction operation along one of the dimensions that can be fused inside the kernel, which is why I cannot use any premade GEMM routines; the full result matrix would be too big. A scalar sketch of the computation follows below.
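In scalar host code, the intended computation looks roughly like this. The four-64-bit-limb layout and the min reduction are only placeholder assumptions for illustration; any reduction that fuses over one axis would do.

#include <cstdint>
#include <vector>
#include <algorithm>

// One 256-bit integer as four 64-bit limbs (layout is an assumption).
struct u256 { uint64_t w[4]; };

// Hamming distance between two 256-bit values: popc(a xor b).
static int dist256(const u256 &a, const u256 &b) {
    int d = 0;
    for (int i = 0; i < 4; i++) d += __builtin_popcountll(a.w[i] ^ b.w[i]);
    return d;
}

// Outer product of distances with the reduction along the B axis fused
// (min is a stand-in for the actual reduction), so the full |A| x |B|
// result matrix is never materialized.
std::vector<int> fused_distance_reduce(const std::vector<u256> &A,
                                       const std::vector<u256> &B) {
    std::vector<int> r(A.size(), 256);  // 256 = maximum possible distance
    for (size_t i = 0; i < A.size(); ++i)
        for (size_t j = 0; j < B.size(); ++j)
            r[i] = std::min(r[i], dist256(A[i], B[j]));
    return r;
}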

So, since my integer size is 256, the bigger BMMA would fit more naturally. Also, I would assume that the bigger routines are faster, which is why I am looking at them in the first place.

I would be very happy to see an example of the instruction in action.

The include/header files are the most relevant authority. My guess is there is some incorrect info in the docs. You may wish to file a bug to confirm/update.


The bug (3634929) is still being looked at. I’m not really sure what the outcome will be; if I had to guess, it will be a doc update of some sort.

As I mentioned previously, I believe the mma.h header file is the authority for what you can do in CUDA C++. So I don’t see that this is exposed via CUDA C++ (yet).
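For reference, what CUDA C++ does expose today is the 8x8x128 bmma_sync from the experimental WMMA API. Here is a minimal sketch from memory; check mma.h for the exact overloads, and note that the leading-dimension convention for b1 fragments (given in elements, i.e. bits, here) is my assumption.

#include <mma.h>
using namespace nvcuda::wmma;
using namespace nvcuda::wmma::experimental;

// One warp computes an 8x8 int tile D = popc(A xor B) + C over k = 128 bits.
// A is 8x128 bits, bit-packed row-major; B is 128x8 bits, bit-packed col-major.
__global__ void bmma_8x8x128(int *D, const unsigned *A, const unsigned *B) {
    fragment<matrix_a, 8, 8, 128, precision::b1, row_major> a_frag;
    fragment<matrix_b, 8, 8, 128, precision::b1, col_major> b_frag;
    fragment<accumulator, 8, 8, 128, int> c_frag;
    fill_fragment(c_frag, 0);
    load_matrix_sync(a_frag, A, 128);  // ldm in bits, multiple of 128
    load_matrix_sync(b_frag, B, 128);
    bmma_sync(c_frag, a_frag, b_frag, c_frag, bmmaBitOpXOR, bmmaAccumulateOpPOPC);
    store_matrix_sync(D, c_frag, 8, mem_row_major);
}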

If you wish to use PTX, the instruction does seem to be documented (mma.sync.aligned.m16n8k256 with .b1 operands and .xor.popc).

Here is what I would call a basic demonstrator for the 16x8x256 case. It seems to compile correctly under CUDA 11.4, but I haven’t tested it yet, so there may be bugs; YMMV. I haven’t tried to make the output striping sensible in any way; I’m just dumping the output to an array without attention to order. Since A is all 1-bits and B alternates 0 and 1 bits, a xor b sets exactly half of the 256 bits in each dot product, so my expectation is that every output element will be half of the maximum value.

#include <cstring>   // memcpy
#include <iostream>

// One warp performs a single m16n8k256 binary MMA:
// D (16x8, s32) = popc(A (16x256 bits) xor B^T (8x256 bits)) + C.
// Each of the 32 threads holds 4 registers of A, 2 of B, and 4 of C/D.
__global__ void mma_16x8x256_binary(unsigned *out) {
    unsigned c[4] = {0, 0, 0, 0};  // accumulator input
    unsigned d[4] = {0, 0, 0, 0};  // accumulator output
    unsigned a[4] = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF};  // A: all 1-bits
    unsigned b[2] = {0x55555555, 0x55555555};  // B: alternating 0/1 bits
    asm(
      "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
        "r"(b[0]), "r"(b[1]),
        "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
    // Dump each thread's 4 result registers; no attempt is made to match
    // the architected fragment-to-matrix mapping.
    memcpy(out + threadIdx.x * 4, d, 4 * sizeof(unsigned));
}

int main() {
    unsigned *h_C = (unsigned *)malloc(16 * 8 * sizeof(unsigned));
    unsigned *d_C;
    cudaMalloc(&d_C, 16 * 8 * sizeof(unsigned));
    mma_16x8x256_binary<<<1, 32>>>(d_C);  // one warp
    cudaDeviceSynchronize();
    cudaMemcpy(h_C, d_C, 16 * 8 * sizeof(unsigned), cudaMemcpyDeviceToHost);
    for (int i = 0; i < 16; i++) {
        for (int j = 0; j < 8; j++) std::cout << h_C[i * 8 + j] << " ";
        std::cout << std::endl;
    }
    cudaFree(d_C);
    free(h_C);
}

This requires compilation for at least Ampere, i.e. -arch=sm_80, -arch=sm_86, or similar/newer.

(I’ve had a chance to test it now, and it works as I expected. The “width” of the multiplication is 256 bits, so the maximum value per output element is 256, and the code outputs all values of 128: each 32-bit chunk of a xor b is 0xFFFFFFFF xor 0x55555555 = 0xAAAAAAAA, which has 16 bits set, and 8 such chunks make up each 256-bit dot product, giving 8 x 16 = 128.)


Thank you for the feedback :) It is very valuable to have curious experts respond on these topics.

I also found the instruction in CUTLASS, so this should indeed be working as intended.
