Ampere 16x8x256 BMMA

The bug (3634929) is still being looked at. I’m not sure what the outcome will be; if I had to guess, it will be a documentation update of some sort.

As I mentioned previously, I believe the mma.h header file is the authority for what you can do in CUDA C++. I don’t see this shape in it, so as far as I can tell it is not exposed via CUDA C++ (yet).

If you wish to use PTX, the instruction does seem to be documented in the PTX ISA (mma.sync.aligned.m16n8k256 with .b1 operands).
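For m16n8k256 with 1-bit A/B and 32-bit accumulators, each thread supplies 4 registers for A, 2 for B, and 4 each for C and D. Those counts follow from the shape: each operand tile is split evenly across the 32 threads of the warp and packed into 32-bit registers. Here is a small compile-time check of that arithmetic in plain C++ (a sketch; the helper name regs_per_thread is my own, not something from mma.h or the PTX docs):

```cpp
// Registers per thread for one operand of a warp-wide mma.sync, assuming the
// rows x cols tile is split evenly over 32 lanes and packed into 32-bit registers.
constexpr int regs_per_thread(int rows, int cols, int bits_per_element) {
    return rows * cols * bits_per_element / (32 /* lanes */ * 32 /* bits per register */);
}

// m16n8k256 with .b1 A/B and .s32 C/D
constexpr int M = 16, N = 8, K = 256;
constexpr int kRegsA  = regs_per_thread(M, K, 1);   // 16x256 bits         -> 4
constexpr int kRegsB  = regs_per_thread(K, N, 1);   // 256x8 bits          -> 2
constexpr int kRegsCD = regs_per_thread(M, N, 32);  // 16x8 x 32-bit ints  -> 4

static_assert(kRegsA == 4, "A fragment is 4 x 32-bit registers");
static_assert(kRegsB == 2, "B fragment is 2 x 32-bit registers");
static_assert(kRegsCD == 4, "C and D fragments are 4 x 32-bit registers each");
```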

Here is what I would call a basic demonstrator for the 16x8x256 case. It seems to compile correctly under CUDA 11.4, but I haven’t tested it yet, so there may be bugs (YMMV). I haven’t tried to make the output ordering sensible in any way; each thread just dumps its four accumulator registers into the output array without regard to row/column order. My expectation is that every output element will be half of the maximum possible value.

```cpp
#include <mma.h>
#include <cuda_fp16.h>
#include <iostream>
#include <stdio.h>
#include <cstdlib>

__global__ void mma_16x8x256_binary(unsigned *out) {
    // Per-thread fragments for m16n8k256 with .b1 inputs:
    // A = 4 x 32-bit regs, B = 2 x 32-bit regs, C/D = 4 x 32-bit (.s32) regs
    unsigned c[4] = {0, 0, 0, 0};
    unsigned d[4] = {0, 0, 0, 0};
    unsigned a[4] = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF};  // all bits set
    unsigned b[2] = {0x55555555, 0x55555555};                          // alternating bits
    unsigned const *A = reinterpret_cast<unsigned const *>(&a);
    unsigned const *B = reinterpret_cast<unsigned const *>(&b);
    unsigned const *C = reinterpret_cast<unsigned const *>(&c);
    unsigned *D = reinterpret_cast<unsigned *>(&d);
    asm volatile(
      "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
        "r"(B[0]), "r"(B[1]),
        "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
    // Dump this thread's 4 accumulator registers contiguously (layout ignored)
    memcpy(out+threadIdx.x*4, D, 16);
}

int main() {
    unsigned* h_C = (unsigned*)malloc(16*8*sizeof(unsigned));
    unsigned* d_C;
    cudaMalloc(&d_C, 16*8*sizeof(unsigned));
    mma_16x8x256_binary<<<1, 32>>>(d_C);  // one warp: mma.sync is warp-wide
    cudaMemcpy(h_C, d_C, 16*8*sizeof(unsigned), cudaMemcpyDeviceToHost);
    for (int i = 0; i < 16; i++){
      for (int j = 0; j < 8; j++) std::cout << h_C[i*8+j] << " ";
      std::cout << std::endl;}
    cudaFree(d_C);
    free(h_C);
    return 0;
}
```
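The kernel above just dumps each thread’s four registers contiguously. To get the result in row-major order instead, each (lane, register) pair has to be mapped to a (row, column) of the 16x8 tile. The mapping below is my reading of the PTX ISA accumulator layout for the m16n8 .s32 shapes (groupID = lane / 4, threadID_in_group = lane % 4); treat it as an assumption and verify it against the docs. It is host-only C++ so it can be checked without a GPU, and the names cd_fragment_coord and layout_covers_tile are hypothetical:

```cpp
// Row/column of accumulator register D[i] (i = 0..3) held by a given lane
// for the m16n8 .s32 C/D fragment (assumed layout, see the PTX ISA).
struct RowCol { int row; int col; };

RowCol cd_fragment_coord(int lane, int i) {
    int group = lane >> 2;           // groupID: 0..7
    int thread_in_group = lane & 3;  // threadID_in_group: 0..3
    int row = group + (i >= 2 ? 8 : 0);          // D[2], D[3] sit 8 rows lower
    int col = thread_in_group * 2 + (i & 1);     // two adjacent columns per lane
    return {row, col};
}

// True if 32 lanes x 4 registers hit every element of the 16x8 tile exactly once.
bool layout_covers_tile() {
    int hits[16][8] = {};
    for (int lane = 0; lane < 32; ++lane)
        for (int i = 0; i < 4; ++i) {
            RowCol rc = cd_fragment_coord(lane, i);
            ++hits[rc.row][rc.col];
        }
    for (int r = 0; r < 16; ++r)
        for (int c = 0; c < 8; ++c)
            if (hits[r][c] != 1) return false;
    return true;
}
```

Marking cd_fragment_coord as `__host__ __device__` would let the kernel write `out[rc.row*8 + rc.col] = D[i]` for i = 0..3 in place of the memcpy. For this particular demonstrator it makes no visible difference, since every output element has the same value.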

This requires compiling for Ampere or newer, i.e. -arch=sm_80, -arch=sm_86, or a similar/newer target.

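(I’ve had a chance to test it now; it works as I expected. The “width” of the multiplication is 256 bits, so the maximum value per output element is 256. XOR of an all-ones word with 0x55555555 leaves 16 of its 32 bits set, so each element comes out to 256/2 = 128, and the code prints 128 for every element.)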
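The expected value can also be checked on the host. Below is a scalar reference for one output element of the xor.popc variant, written as a sketch (the function names are hypothetical, not from any CUDA header): it XORs one 256-bit row of A with one 256-bit column of B, counts the set bits, and adds C. Because every register in the demonstrator holds the same bit pattern, every element gets the same value regardless of fragment layout.

```cpp
#include <cstdint>

// Portable 32-bit population count (clears the lowest set bit each iteration).
int popcount32(uint32_t x) {
    int n = 0;
    while (x) { x &= x - 1; ++n; }
    return n;
}

// One element of D = C + popc(A xor B) over K = 256 bits,
// with the A row and the B column each given as 8 x 32-bit words.
int32_t bmma_xor_popc_element(const uint32_t a_row[8], const uint32_t b_col[8], int32_t c) {
    int32_t d = c;
    for (int w = 0; w < 8; ++w)
        d += popcount32(a_row[w] ^ b_col[w]);
    return d;
}

// The demonstrator's inputs: every bit of A set, B = alternating bits, C = 0.
int32_t demo_element() {
    uint32_t a_row[8], b_col[8];
    for (int w = 0; w < 8; ++w) { a_row[w] = 0xFFFFFFFFu; b_col[w] = 0x55555555u; }
    return bmma_xor_popc_element(a_row, b_col, 0);
}
```

demo_element() evaluates to 8 words x 16 set bits = 128, matching the GPU output; feeding an all-ones row against an all-zeros column gives the 256 maximum.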