Warp shuffle instruction not working as expected

I took the source code from the following link:

and refactored it to use registers and shuffle instructions.

However, the output is not correct:

Result Matrix:
    28      0     36      0
    38      0     50      0
   108      0    148      0
   118      0    162      0

What am I doing incorrectly?

#include <iostream>
#include <iomanip>

using namespace std;

const int ROWS1 = 4;
const int COLS1 = 4;

const int ROWS2 = 4;
const int COLS2 = 4;

const int ROWS3 = ROWS1;
const int COLS3 = COLS2;

const int TILE_ROW_SIZE = 2;
const int TILE_COL_SIZE = 2;

#define IDX(tile_size, tile_i, relative_i) ((tile_size) * (tile_i) + (relative_i))


__global__ void MultiplyAsSumOuterProductOfVectors(int *A, int *B, int *C,
  int tile_row_size, int tile_col_size,
  int cols1, int rows1, int cols2)
{
  int tile_i = blockIdx.y;
  int tile_j = blockIdx.x;

  int cell_i = threadIdx.y;
  int cell_j = threadIdx.x;

  // column vectors of matrices A and B are stored in registers before multiplication
  int regA[TILE_ROW_SIZE];
  int regB[TILE_ROW_SIZE];

  // the multiplication result is stored in shared memory
  __shared__ int shrdC[TILE_ROW_SIZE][TILE_COL_SIZE+1];
  shrdC[cell_i][cell_j] = 0;

  __syncthreads();

  for (int tile_r = 0; tile_r < cols2/tile_col_size; tile_r++)
  {
    regA[cell_i] = A[IDX(cols1, tile_i*TILE_ROW_SIZE+cell_i, tile_r*TILE_COL_SIZE+cell_j)];
    regB[cell_j] = B[IDX(cols2, tile_r*TILE_ROW_SIZE+cell_i, tile_j*TILE_COL_SIZE+cell_j)];

    __syncthreads();

    for (int cell_r = 0; cell_r < TILE_ROW_SIZE; cell_r++)
    {
      // the kernel uses shuffle instructions to replace the contents of registers
      int valA = __shfl_sync(0xffffffff, regA[cell_r], cell_i);
      int valB = __shfl_sync(0xffffffff, regB[cell_r], cell_j);
      shrdC[cell_i][cell_j] += valA * valB;
    }

    __syncthreads();
  }

  int c_i = tile_i * TILE_ROW_SIZE + cell_i;
  int c_j = tile_j * TILE_COL_SIZE + cell_j;

  if (c_i < rows1 && c_j < cols2)
  {
    C[IDX(COLS3, c_i, c_j)] = shrdC[cell_i][cell_j];
  }
}


void printMatrix(int *mat, int rr, int cc) {
  for (int i = 0; i < rr; i++) {
    for (int j = 0; j < cc; j++) {
      cout << setw(6) << mat[i * cc + j] << " ";
    }
    cout << endl;
  }
  cout << endl;
}

void allocateMatrix(int *&a, int rows, int cols) {
  a = new int[rows * cols];
}

void freeMatrix(int *a) {
  delete[] a;
}

void initMatrix(int *mat, int rr, int cc) {
  int init = 1;
  for (int i = 0; i < rr; i++) {
    for (int j = 0; j < cc; j++) {
      mat[i * cc + j] = init++;
    }
  }
}

void initMatrixZero(int *mat, int rr, int cc) {
  for (int i = 0; i < rr; i++) {
    for (int j = 0; j < cc; j++) {
      mat[i * cc + j] = 0;
    }
  }
}

// normally a thread block would contain at least one full warp (32 threads);
// however, for this experiment we are using 2x2 thread blocks.
int main() {
  int *A, *B, *C;
  int *d_A, *d_B, *d_C;

  allocateMatrix(A, ROWS1, COLS1);
  initMatrix(A, ROWS1, COLS1);

  allocateMatrix(B, ROWS2, COLS2);
  initMatrix(B, ROWS2, COLS2);

  allocateMatrix(C, ROWS3, COLS3);
  initMatrixZero(C, ROWS3, COLS3);

  cudaMalloc((void **)&d_A, ROWS1 * COLS1 * sizeof(int));
  cudaMalloc((void **)&d_B, ROWS2 * COLS2 * sizeof(int));
  cudaMalloc((void **)&d_C, ROWS3 * COLS3 * sizeof(int));
  cudaMemset(d_C, 0, ROWS3 * COLS3 * sizeof(int));


  cudaMemcpy(d_A, A, ROWS1 * COLS1 * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_B, B, ROWS2 * COLS2 * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_C, C, ROWS3 * COLS3 * sizeof(int), cudaMemcpyHostToDevice);


  dim3 blocks(COLS3/TILE_COL_SIZE, ROWS3/TILE_ROW_SIZE);
  dim3 threads(TILE_COL_SIZE, TILE_ROW_SIZE);
  MultiplyAsSumOuterProductOfVectors<<<blocks, threads>>>(d_A, d_B, d_C, TILE_ROW_SIZE, TILE_COL_SIZE, COLS1, ROWS1, COLS2);


  cudaMemcpy(C, d_C, ROWS3 * COLS3 * sizeof(int), cudaMemcpyDeviceToHost);

  cout << "Result Matrix:" << endl;
  printMatrix(C, ROWS3, COLS3);

  cudaFree(d_A);
  cudaFree(d_B);
  cudaFree(d_C);

  freeMatrix(A);
  freeMatrix(B);
  freeMatrix(C);

  return 0;
}

I see a variety of problems or things that look strange to me. Let’s consider the very first pass through your code. This statement:

int regA[TILE_ROW_SIZE];

does not initialize any values in that array (just as such a declaration does not in ordinary C or C++). Furthermore, keep in mind that this is not shared memory; each thread has its own local copy of regA[].

On the first pass through the tile for-loop, this statement:

regA[cell_i] = A[IDX(cols1, tile_i*TILE_ROW_SIZE+cell_i, tile_r*TILE_COL_SIZE+cell_j)];

loads exactly one value into a particular thread’s copy of regA. For the thread with cell_i equal to zero, for example, that statement will populate regA[0]. For that particular thread, the other values in regA are uninitialized at that moment.

Right after that we get to this nested for-loop:

for (int cell_r = 0; cell_r < TILE_ROW_SIZE; cell_r++)
{
  // the kernel uses shuffle instructions to replace the contents of registers
  int valA = __shfl_sync(0xffffffff, regA[cell_r], cell_i);
  int valB = __shfl_sync(0xffffffff, regB[cell_r], cell_j);
  shrdC[cell_i][cell_j] += valA * valB;
}

So the cell_r index variable will index over the tile, i.e. it will take on values of 0 and 1 in this case. Given that, what do you suppose happens here for the thread whose cell_i is zero:

  int valA = __shfl_sync(0xffffffff, regA[cell_r], cell_i);

On the pass of the loop where cell_r is zero, it is referencing regA[0], which is OK. But on the pass of this nested for-loop where cell_r is 1, it is referencing regA[1], which is uninitialized at that moment for the thread whose cell_i is zero.

That is undefined behavior in C and C++.

A few other things to point out:

  • given that your outer for-loop only loads one value into each of regA and regB per pass, either it is not necessary for you to have an array for regA or regB (a single scalar variable for each would suffice), or else you don’t understand what is going on, and I’m not able to explain it since you’ve offered no explanation for what is “expected” other than code.
  • your use of the shared array to store temporary results is also superfluous. Each thread is only carrying a single sum-result for a given point in the output matrix; it would be more sensible just to use a single scalar variable to carry this running sum per thread (a minimal sketch of that idea follows this list).
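To make that second point concrete, here is a minimal sketch (my own illustration; the kernel name is made up, and it deliberately drops the tiling and shuffle aspects) of each thread carrying its running sum in an ordinary register and writing it out once:

__global__ void MatMulNaiveSketch(const int *A, const int *B, int *C,
                                  int rows1, int cols1, int cols2)
{
    int i = blockIdx.y * blockDim.y + threadIdx.y;  // output row
    int j = blockIdx.x * blockDim.x + threadIdx.x;  // output column
    if (i >= rows1 || j >= cols2) return;

    int sum = 0;                                    // per-thread accumulator in a register
    for (int k = 0; k < cols1; k++)
        sum += A[i * cols1 + k] * B[k * cols2 + j]; // one dot-product per thread
    C[i * cols2 + j] = sum;
}

No shared array is needed to hold the running sum; each thread simply keeps it in a register until the final store.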

… or else you don’t understand what is going on, and I’m not able to explain it since you’ve offered no explanation for what is “expected” other than code.

  1. I want to use registers for keeping component vectors before multiplication.
  2. I want to use shared memory for keeping intermediate multiplication results for summing them up.
  3. I want to use warp shuffle instructions for loading vectors into registers.

One more thing I am not sure about: warps operate on 32 threads, while I am using 2x2 thread blocks. Does that create any problem?

In CUDA, we generally prefer that adjacent threads in a warp load adjacent data from memory, for performance reasons. This falls under the general topic area of “coalescing” in CUDA.

For reasonably sized problems, this means that adjacent threads will read along a row of a 2D matrix (as it is normally laid out in C++). We already know from your previous matrix-multiply codes that, when dealing with the A matrix in the equation C = AxB, we do indeed need a thread to index along a row in order to retrieve the elements of A it needs to compute a single output element (i.e. a single vector dot-product). This might be amenable, in a limited way, to using warp shuffle as a way to exchange A “vector data” between threads, perhaps eliminating the need for tile storage for A such as a shared memory tile.

However for the B matrix (or tile) each thread needs to collect values from a column vector in B, i.e. each thread needs to index along a column. If we stick to our desire to load in a coalesced fashion, the warp shuffle does not immediately present a way, like it does for the A matrix, to allow for useful exchange of B vector elements between threads in a warp. The threads in a warp would be reading row vectors, not column vectors.
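In code terms, the distinction looks roughly like this (illustration only; the kernel and names are made up, and it assumes a square ld x ld matrix with a matching launch):

__global__ void accessPatterns(const float *M, float *out, int ld)
{
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    float along_row = M[row * ld + col];   // adjacent threadIdx.x -> adjacent addresses (coalesced)
    float along_col = M[col * ld + row];   // adjacent threadIdx.x -> addresses ld floats apart (uncoalesced)

    out[row * ld + col] = along_row + along_col;  // keep both loads live
}

The first load pattern is what we want for bulk traffic; the second is the pattern a naive per-thread column read of B would produce.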

It might be possible to use a methodology similar to a register-based transpose to address this aspect of using values from the B tile, but this would reduce us from using a full complement of threads in the threadblock to only using a single “row” of threads. That doesn’t look very attractive to me, at first glance.

But if you did that, you could probably eliminate the use of shared memory altogether, for, say, 32x32 tiles, using only a single warp of 32 threads. This would address your items 1 and 3. Item 2 doesn’t make much sense to me, and doing that is not something that I would recommend for sensible or performant CUDA code, therefore I personally would not spend any time on that, even though it is basically trivial (and you have already implemented it, anyway).

For the case of your “toy” problem with 2x2 threadblocks, all the threads in that tile/threadblock would belong to a single warp, so it should be possible to come up with a shuffling pattern that would allow you to eliminate the use of shared memory tiles there.
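For instance, a rough, untested sketch of such a pattern might look like the following. It reuses the IDX macro and tile constants from your code, keeps each thread’s A element, B element, and running sum in scalar registers, and uses a shuffle mask of 0xF because a 2x2 block only occupies four lanes of its warp:

__global__ void MultiplyOuterProduct2x2(const int *A, const int *B, int *C,
                                        int cols1, int rows1, int cols2)
{
  int tile_i = blockIdx.y;
  int tile_j = blockIdx.x;
  int cell_i = threadIdx.y;   // lane id within the warp = cell_i*2 + cell_j
  int cell_j = threadIdx.x;

  int sum = 0;                // per-thread running sum, no shared memory

  for (int tile_r = 0; tile_r < cols1 / TILE_COL_SIZE; tile_r++)
  {
    // each thread loads exactly one element of the current A tile and B tile
    int regA = A[IDX(cols1, tile_i*TILE_ROW_SIZE + cell_i, tile_r*TILE_COL_SIZE + cell_j)];
    int regB = B[IDX(cols2, tile_r*TILE_ROW_SIZE + cell_i, tile_j*TILE_COL_SIZE + cell_j)];

    for (int cell_r = 0; cell_r < TILE_ROW_SIZE; cell_r++)
    {
      // A tile element (cell_i, cell_r) lives in lane cell_i*2 + cell_r
      int valA = __shfl_sync(0xF, regA, cell_i * 2 + cell_r);
      // B tile element (cell_r, cell_j) lives in lane cell_r*2 + cell_j
      int valB = __shfl_sync(0xF, regB, cell_r * 2 + cell_j);
      sum += valA * valB;
    }
  }

  int c_i = tile_i * TILE_ROW_SIZE + cell_i;
  int c_j = tile_j * TILE_COL_SIZE + cell_j;
  if (c_i < rows1 && c_j < cols2)
    C[IDX(cols2, c_i, c_j)] = sum;
}

Each thread only ever loads and holds one A value and one B value per tile step; the shuffles broadcast those values to whichever lanes need them, so neither the per-thread arrays nor the shared C tile are required.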

Beyond that rough sketch, if time permits I may try to write and test some actual code, but I don’t have anything further to share at this time. Working on a 2x2 problem is useful if the concepts employed are readily extensible to larger sizes, but they would not be extensible directly beyond a 4x4 threadblock size, which isn’t interesting to me, and I doubt it would be interesting from a performance perspective.

I tried a simple experiment based on my previous comments: the low-hanging fruit is to try using warp shuffle in place of the A shared tile. I used the code presented in the programming guide as a starting point, selecting 32x32 tiles. Interestingly, the shuffle variant is noticeably slower than the shared-memory version on an L4 GPU with CUDA 12.2.1 for reasonably large matrix sizes (4096x4096).

Here’s what I came up with:

# cat t15.cu
#include <iostream>
#include <cstdlib>
#include <time.h>
#include <sys/time.h>
#define USECPSEC 1000000ULL

unsigned long long dtime_usec(unsigned long long start=0){

  timeval tv;
  gettimeofday(&tv, 0);
  return ((tv.tv_sec*USECPSEC)+tv.tv_usec)-start;
}
#define RNG 3
// Thread block size
#define BLOCK_SIZE 32
// Matrices are stored in row-major order:
// M(row, col) = *(M.elements + row * M.stride + col)
typedef struct {
    int width;
    int height;
    int stride;
    float* elements;
} Matrix;
// Get a matrix element
__device__ float GetElement(const Matrix A, int row, int col)
{
    return A.elements[row * A.stride + col];
}
// Set a matrix element
__device__ void SetElement(Matrix A, int row, int col,
                           float value)
{
    A.elements[row * A.stride + col] = value;
}
// Get the BLOCK_SIZExBLOCK_SIZE sub-matrix Asub of A that is
// located col sub-matrices to the right and row sub-matrices down
// from the upper-left corner of A
 __device__ Matrix GetSubMatrix(Matrix A, int row, int col)
{
    Matrix Asub;
    Asub.width    = BLOCK_SIZE;
    Asub.height   = BLOCK_SIZE;
    Asub.stride   = A.stride;
    Asub.elements = &A.elements[A.stride * BLOCK_SIZE * row
                                         + BLOCK_SIZE * col];
    return Asub;
}

// Matrix multiplication kernel called by MatMul()
 __global__ void MatMulKernel(Matrix A, Matrix B, Matrix C)
{
    // Block row and column
    int blockRow = blockIdx.y;
    int blockCol = blockIdx.x;
    // Each thread block computes one sub-matrix Csub of C
    Matrix Csub = GetSubMatrix(C, blockRow, blockCol);
    // Each thread computes one element of Csub
    // by accumulating results into Cvalue
    float Cvalue = 0;
    // Thread row and column within Csub
    int row = threadIdx.y;
    int col = threadIdx.x;
    // Loop over all the sub-matrices of A and B that are
    // required to compute Csub
    // Multiply each pair of sub-matrices together
    // and accumulate the results
    for (int m = 0; m < (A.width / BLOCK_SIZE); ++m) {
        // Get sub-matrix Asub of A
        Matrix Asub = GetSubMatrix(A, blockRow, m);
        // Get sub-matrix Bsub of B
        Matrix Bsub = GetSubMatrix(B, m, blockCol);
        // Shared memory used to store Asub and Bsub respectively
        __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];
        __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];
        // Load Asub and Bsub from device memory to shared memory
        // Each thread loads one element of each sub-matrix
        As[row][col] = GetElement(Asub, row, col);
        Bs[row][col] = GetElement(Bsub, row, col);
        // Synchronize to make sure the sub-matrices are loaded
        // before starting the computation
        __syncthreads();
        // Multiply Asub and Bsub together
        for (int e = 0; e < BLOCK_SIZE; ++e)
            Cvalue += As[row][e] * Bs[e][col];
        // Synchronize to make sure that the preceding
        // computation is done before loading two new
        // sub-matrices of A and B in the next iteration
        __syncthreads();
    }
    // Write Csub to device memory
    // Each thread writes one element
    SetElement(Csub, row, col, Cvalue);
}


 __global__ void MatMulKernel_shflA(Matrix A, Matrix B, Matrix C)
{
    // Block row and column
    int blockRow = blockIdx.y;
    int blockCol = blockIdx.x;
    // Each thread block computes one sub-matrix Csub of C
    Matrix Csub = GetSubMatrix(C, blockRow, blockCol);
    // Each thread computes one element of Csub
    // by accumulating results into Cvalue
    float Cvalue = 0;
    // Thread row and column within Csub
    int row = threadIdx.y;
    int col = threadIdx.x;
    // Loop over all the sub-matrices of A and B that are
    // required to compute Csub
    // Multiply each pair of sub-matrices together
    // and accumulate the results
    for (int m = 0; m < (A.width / BLOCK_SIZE); ++m) {
        // Get sub-matrix Asub of A
        Matrix Asub = GetSubMatrix(A, blockRow, m);
        // Get sub-matrix Bsub of B
        Matrix Bsub = GetSubMatrix(B, m, blockCol);
        // Shared memory used to store Asub and Bsub respectively
        __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];
        float my_A = GetElement(Asub, row, col);
        Bs[row][col] = GetElement(Bsub, row, col);
        // Synchronize to make sure the sub-matrices are loaded
        // before starting the computation
        __syncthreads();
        // Multiply Asub and Bsub together
        for (int e = 0; e < BLOCK_SIZE; ++e)
            Cvalue += __shfl_sync(0xFFFFFFFF, my_A, e) * Bs[e][col];
        // Synchronize to make sure that the preceding
        // computation is done before loading two new
        // sub-matrices of A and B in the next iteration
        __syncthreads();
    }
    // Write Csub to device memory
    // Each thread writes one element
    SetElement(Csub, row, col, Cvalue);
}

// Matrix multiplication - Host code
// Matrix dimensions are assumed to be multiples of BLOCK_SIZE
void MatMul(const Matrix A, const Matrix B, Matrix C)
{
    // Load A and B to device memory
    Matrix d_A;
    d_A.width = d_A.stride = A.width; d_A.height = A.height;
    size_t size = A.width * A.height * sizeof(float);
    cudaMalloc(&d_A.elements, size);
    cudaMemcpy(d_A.elements, A.elements, size,
               cudaMemcpyHostToDevice);
    Matrix d_B;
    d_B.width = d_B.stride = B.width; d_B.height = B.height;
    size = B.width * B.height * sizeof(float);
    cudaMalloc(&d_B.elements, size);
    cudaMemcpy(d_B.elements, B.elements, size,
    cudaMemcpyHostToDevice);
    Matrix h_C_shfl;
    h_C_shfl.width = h_C_shfl.stride = C.width; h_C_shfl.height = C.height;
    h_C_shfl.elements = new float[h_C_shfl.width*h_C_shfl.height];
    // Allocate C in device memory
    Matrix d_C;
    d_C.width = d_C.stride = C.width; d_C.height = C.height;
    size = C.width * C.height * sizeof(float);
    cudaMalloc(&d_C.elements, size);
    // Invoke kernel
    dim3 dimBlock(BLOCK_SIZE, BLOCK_SIZE);
    dim3 dimGrid(B.width / dimBlock.x, A.height / dimBlock.y);
    MatMulKernel<<<dimGrid, dimBlock>>>(d_A, d_B, d_C); // warm-up
    cudaDeviceSynchronize();
    unsigned long long dt = dtime_usec(0);
    MatMulKernel<<<dimGrid, dimBlock>>>(d_A, d_B, d_C);
    cudaDeviceSynchronize();
    dt = dtime_usec(dt);
    std::cout << "shared kernel time:  " << dt << "us" << std::endl;
    cudaMemcpy(C.elements, d_C.elements, size, cudaMemcpyDeviceToHost);
    MatMulKernel_shflA<<<dimGrid, dimBlock>>>(d_A, d_B, d_C); // warm-up
    cudaDeviceSynchronize();
    dt = dtime_usec(0);
    MatMulKernel_shflA<<<dimGrid, dimBlock>>>(d_A, d_B, d_C);
    cudaDeviceSynchronize();
    dt = dtime_usec(dt);
    std::cout << "shuffle kernel time: " << dt << "us" << std::endl;
    // Read C from device memory
    cudaMemcpy(h_C_shfl.elements, d_C.elements, size, cudaMemcpyDeviceToHost);
    for (int i = 0; i < h_C_shfl.width*h_C_shfl.height; i++) if (h_C_shfl.elements[i] != C.elements[i]) {std::cout << "mismatch at: " << i << " should be: " << C.elements[i] << " was: " << h_C_shfl.elements[i] << std::endl; return;}
    // Free device memory
    cudaFree(d_A.elements);
    cudaFree(d_B.elements);
    cudaFree(d_C.elements);
}

int main(){

  Matrix h_A, h_B, h_C;
  const int dim = 128*BLOCK_SIZE;
  h_A.width=h_A.stride=h_A.height=h_B.width=h_B.stride=h_B.height=h_C.width=h_C.stride=h_C.height=dim;
  h_A.elements = new float[dim*dim];
  h_B.elements = new float[dim*dim];
  h_C.elements = new float[dim*dim];
  for (int i = 0; i < dim*dim; i++) {
    h_A.elements[i] = rand()%RNG;
    h_B.elements[i] = rand()%RNG;}
  MatMul(h_A, h_B, h_C);
}
# nvcc -o t15 t15.cu -arch=sm_89
# ./t15
shared kernel time:  69830us
shuffle kernel time: 94572us
#

So this is only one data point, but the initial conclusion I draw is that warp shuffle for this case doesn’t seem to be performance-advantageous vs. using shared memory. I would encourage all readers not to draw sweeping conclusions from this one data point: there are certainly examples where refactoring a shared memory code to use warp shuffle can result in improved performance.
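For instance (just as a well-known example, not something benchmarked here), the textbook warp-level sum reduction replaces a shared-memory staging buffer with a handful of shuffles:

// Classic warp-level reduction: each lane contributes one value and,
// after log2(32) steps, lane 0 holds the sum of all 32 lanes, with no
// shared memory involved.
__inline__ __device__ float warpReduceSum(float val)
{
    for (int offset = warpSize / 2; offset > 0; offset >>= 1)
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    return val;   // only lane 0 holds the complete sum
}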

void MatMul(const Matrix A, const Matrix B, Matrix C)
{
    // ... ... ...
    // Invoke kernel
    dim3 dimBlock(BLOCK_SIZE, BLOCK_SIZE);
    dim3 dimGrid(B.width / dimBlock.x, A.height / dimBlock.y);
    MatMulKernel<<<dimGrid, dimBlock>>>(d_A, d_B, d_C); // warm-up
    // ... ... ...
    MatMulKernel<<<dimGrid, dimBlock>>>(d_A, d_B, d_C);
    // ... ... ...
    MatMulKernel_shflA<<<dimGrid, dimBlock>>>(d_A, d_B, d_C); // warm-up
    // ... ... ...
    MatMulKernel_shflA<<<dimGrid, dimBlock>>>(d_A, d_B, d_C);
    // ... ... ...
}

I have two questions in this regard:

  1. Why have you called each kernel twice?
  2. Why does commenting out MatMulKernel give an incorrect result? Why do we have to call both kernels?

For benchmarking purposes. The first time you run a function, it may take longer. I’m interested in the shorter measurement.

Because I am using the results from MatMulKernel to test the correctness of MatMulKernel_shflA. I expect the results to match. If you comment out one kernel, the “results” are not going to match. This line of code is testing for equality between the two result sets:

   for (int i = 0; i < h_C_shfl.width*h_C_shfl.height; i++) if (h_C_shfl.elements[i] != C.elements[i]) ...
