How can I write a printMatrix kernel?

The given CUDA source code generates the following output:

user_name@cuda_server:~/$ nvcc print_matrix_kernel.cu -o exe
user_name@cuda_server:~/$ ./exe
Result Matrix:
     1      2      3      4
     5      6      7      8
     9     10     11     12
    13     14     15     16

3 4 7 8 9 10 13 14 1 2 5 6 11 12 15 16

How can I write the printMatrixGPU kernel to display the same 4x4 formatting as generated by the CPU printMatrix()?

#include <iostream>
#include <iomanip>

using namespace std;

const int ROWS_Y = 4;//1024;
const int COLS_X = 4;//1024;
const int TILE_ROWS_Y = 2;//32;
const int TILE_COLS_X = 2;//32;

#define GX(tx, lx) ((tx) * (TILE_COLS_X) + (lx))   // global x from block x and thread x
#define GY(ty, ly) ((ty) * (TILE_ROWS_Y) + (ly))   // global y from block y and thread y
#define GID2(gx, gy) ((gy) * (COLS_X) + (gx))      // row-major linear index from (gx, gy)
#define GID4(tx, ty, lx, ly) ((GY(ty, ly)) * (COLS_X) + (GX(tx, lx))) // linear index from block/thread indices
#define MOSAIC_ROWS_Y ((ROWS_Y) / (TILE_ROWS_Y))   // grid height in tiles
#define MOSAIC_COLS_X ((COLS_X) / (TILE_COLS_X))   // grid width in tiles
#define LID(lx, ly) ((ly)*(TILE_COLS_X)+(lx))      // linear index within one tile

__global__ void printMatrixGPU(int *A, int rows, int cols)
{
    int ty = blockIdx.y;
    int tx = blockIdx.x;

    int ly = threadIdx.y;
    int lx = threadIdx.x; 

    int gy = GY(ty, ly);
    int gx = GX(tx, lx);

    int idx = GID2(gx, gy); 

    printf("%d ", A[idx]);

    if(gx == cols-1)
    {
        printf("\n");
    }
}

void printMatrix(int *mat, int rr, int cc) {
  for (int i = 0; i < rr; i++) {
    for (int j = 0; j < cc; j++) {
      cout << setw(6) << mat[i * cc + j] << " ";
    }
    cout << endl;
  }
  cout << endl;
}

void allocateMatrix(int *&a, int rows, int cols) {
  a = new int[rows * cols];
}

void freeMatrix(int *a) {
  delete[] a;
}

void initMatrix(int *mat, int rr, int cc) {
  int init = 1;
  for (int i = 0; i < rr; i++) {
    for (int j = 0; j < cc; j++) {
      mat[i * cc + j] = init++;
    }
  }
}

void initMatrixZero(int *mat, int rr, int cc) {
  for (int i = 0; i < rr; i++) {
    for (int j = 0; j < cc; j++) {
      mat[i * cc + j] = 0;
    }
  }
}

int main() {
  int *A;
  int *d_A;

  allocateMatrix(A, ROWS_Y, COLS_X);
  initMatrix(A, ROWS_Y, COLS_X);

  // Allocate device memory
  cudaMalloc((void **)&d_A, ROWS_Y * COLS_X * sizeof(int));

  // Copy input matrices from host to device
  cudaMemcpy(d_A, A, ROWS_Y * COLS_X * sizeof(int), cudaMemcpyHostToDevice);

  // Set grid and block dimensions
  dim3 gridSize(MOSAIC_COLS_X, MOSAIC_ROWS_Y);
  dim3 blockSize(TILE_COLS_X, TILE_ROWS_Y);

  printMatrixGPU<<<gridSize, blockSize>>>(d_A, ROWS_Y, COLS_X);
  // note: with no cudaDeviceSynchronize() here, the device printf buffer is
  // not flushed until a later synchronization point, which is why the GPU
  // output appears after the CPU printout below

  // Print the result matrix
  cout << "Result Matrix:" << endl;
  printMatrix(A, ROWS_Y, COLS_X);

  // Free device memory
  cudaFree(d_A);

  // Free host memory
  freeMatrix(A);

  return 0;
}

A general process and possible methodology are covered in my response here.

Doesn’t work!

__global__ void printMatrixGPU(int *A, int rows, int cols)
{
    int ty = blockIdx.y;
    int tx = blockIdx.x;

    int ly = threadIdx.y;
    int lx = threadIdx.x; 

    int gy = GY(ty, ly);
    int gx = GX(tx, lx);

    int idx = GID2(gx, gy); 

    if (idx < rows*cols) printf("%d ", A[idx]);
    __syncthreads();                      
    if (!gx) printf("\n");       
    __syncthreads(); 
}

Output:

Result Matrix:
     1      2      3      4
     5      6      7      8
     9     10     11     12
    13     14     15     16

3 4 7 8 9 10 13 14 1 2 5 6 11 12 15 16

See if you can figure out how to write an in-kernel printf statement that just prints one line of your desired printout. Then write a loop around that statement. That’s what I did in the answer I linked, and your example that “Doesn’t work!” doesn’t have that loop, among other issues.

Me writing a fully working code example evidently did not work. So I’m not going to do it again. We can take a different approach to learning, perhaps: one step at a time.

See if you can figure out how to make a kernel that just prints this:

     1      2      3      4

Once you have mastered that, I can suggest a next step.

CUDA doesn’t provide any block-level ordering of execution. While it is possible to order execution among blocks, it’s quite tedious. My suggestion would be to solve this problem at the single-block level. At least do that first, before tackling the multi-block case. Do as you wish, of course.
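
For reference, one tedious way to order the printing across blocks is to serialize them with a global atomic “turn” counter. This is a sketch only (printMatrixOrdered and printTurn are illustrative names, the counter must be zeroed before each launch, and the scheme assumes all blocks are resident on the GPU at once, otherwise it can deadlock):

    __device__ int printTurn = 0;   // global turn counter, illustration only

    __global__ void printMatrixOrdered(const int *A, int rows, int cols)
    {
        if (threadIdx.x == 0 && threadIdx.y == 0) {
            int myRow = blockIdx.x;                        // one block per row
            while (atomicAdd(&printTurn, 0) != myRow) { }  // spin until it is our turn
            for (int j = 0; j < cols; j++)
                printf("%6d ", A[myRow * cols + j]);
            printf("\n");
            atomicAdd(&printTurn, 1);                      // hand off to the next block
        }
    }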

I have come up with this:

__global__ void printMatrixGPU(int *A, int rows, int cols)
{
    for(int i=0 ; i<rows ; i++)
    {
        if(threadIdx.y==0 && threadIdx.x==i)
            printf("\n");
        __syncthreads();

        if(threadIdx.y==i)
            printf("%d ", A[GID4(0,0, threadIdx.x, i)]);
        __syncthreads();
    }
}

Result Matrix:
     1      2      3      4
     5      6      7      8
     9     10     11     12
    13     14     15     16


1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16 

Well, you seem to be quite close to a solution. You really only want to print one newline character per line. So pick a single thread to do that. How do you pick a single thread? Check for a specific (x,y) thread index. Something like this:

    if(threadIdx.y==0 && threadIdx.x==0)
        printf("\n");

This isn’t a valid boolean check for what you want to do:

threadIdx.x=i)

A boolean equality test requires two equals signs in C++.
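
For example, with ordinary ints:

    int x = 0, i = 1;
    if (x = i)  printf("first\n");   // assignment: stores i in x, then tests the stored value
    if (x == i) printf("second\n");  // comparison: tests whether x equals i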

Oh, that was a typo!

Well, you seem to have something mostly working then.
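
If you also want to match the CPU’s setw(6) column width, the device printf accepts a field width specifier, e.g. in your print statement:

    printf("%6d ", A[GID4(0, 0, threadIdx.x, i)]);  // %6d right-aligns in a 6-character field, like setw(6)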


Now, the question is:

How can I use this printMatrixGPU() kernel to print matrices inside our __global__ void MatMulKernel(Matrix A, Matrix B, Matrix C) kernel?

Printing from a GPU is not a high-performance path, and is generally not something I would do in production code. Therefore I interpret most such requests as based on either learning or debugging code.

In both of those cases, in my opinion, performance is not an issue.

Therefore I would suggest one of two options:

  1. Copy the code from your posting here (just the body of the kernel, not the whole kernel) into the place where you want to print it.

  2. (What I would do:) Just use a single thread to print out anything you need. You already know how to print out what you want from a single thread; you have demonstrated that knowledge already in your CPU code in this posting.

    Use that same code from your kernel:

    if ((threadIdx.x == 0) && (threadIdx.y == 0)){
    // your printf code here
    }
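
Applied to the 4x4 int matrix from the original posting, option 2 might look like this (a sketch; printMatrixDev is an illustrative helper mirroring the CPU printMatrix, with %6d standing in for setw(6)):

    __device__ void printMatrixDev(const int *mat, int rr, int cc) {
      for (int i = 0; i < rr; i++) {
        for (int j = 0; j < cc; j++) {
          printf("%6d ", mat[i * cc + j]);   // same row-major walk as the CPU printMatrix
        }
        printf("\n");
      }
      printf("\n");
    }

    __global__ void printMatrixGPU(int *A, int rows, int cols)
    {
        if ((threadIdx.x == 0) && (threadIdx.y == 0) &&
            (blockIdx.x == 0) && (blockIdx.y == 0))
            printMatrixDev(A, rows, cols);   // a single thread prints the whole matrix in order
    }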
    

I want to print the shared memory for each iteration of the loops.

I see. Try it.

You could add something like this to your code:

__host__ __device__ void printShMatrix(float *mat, int rr, int cc) {
  for (int i = 0; i < rr; i++) {
    for (int j = 0; j < cc; j++) {
      printf("%.2f ", mat[i*cc+j]);
    }
    printf("\n");
  }
  printf("\n");
}

Then in your matrix multiply code, where you want to print out the shared memory, do something like this:

if ((threadIdx.x == 0) && (threadIdx.y == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) printShMatrix(As[0], BLOCK_DIM, BLOCK_DIM);

But this doesn’t show the elements in rows and columns.

Did you try it?

Here is an example.

# cat t15.cu
#include <iostream>
#include <cstdlib>
#include <time.h>
#include <sys/time.h>
#define USECPSEC 1000000ULL


__host__ __device__ void printShMatrix(float *mat, int rr, int cc) {
  for (int i = 0; i < rr; i++) {
    for (int j = 0; j < cc; j++) {
      printf("%.2f ", mat[i*cc+j]);
    }
    printf("\n");
  }
  printf("\n");
}

unsigned long long dtime_usec(unsigned long long start=0){

  timeval tv;
  gettimeofday(&tv, 0);
  return ((tv.tv_sec*USECPSEC)+tv.tv_usec)-start;
}
#define RNG 3
// Thread block size
#define BLOCK_SIZE 4
// Matrices are stored in row-major order:
// M(row, col) = *(M.elements + row * M.stride + col)
typedef struct {
    int width;
    int height;
    int stride;
    float* elements;
} Matrix;
// Get a matrix element
__device__ float GetElement(const Matrix A, int row, int col)
{
    return A.elements[row * A.stride + col];
}
// Set a matrix element
__device__ void SetElement(Matrix A, int row, int col,
                           float value)
{
    A.elements[row * A.stride + col] = value;
}
// Get the BLOCK_SIZExBLOCK_SIZE sub-matrix Asub of A that is
// located col sub-matrices to the right and row sub-matrices down
// from the upper-left corner of A
 __device__ Matrix GetSubMatrix(Matrix A, int row, int col)
{
    Matrix Asub;
    Asub.width    = BLOCK_SIZE;
    Asub.height   = BLOCK_SIZE;
    Asub.stride   = A.stride;
    Asub.elements = &A.elements[A.stride * BLOCK_SIZE * row
                                         + BLOCK_SIZE * col];
    return Asub;
}

// Matrix multiplication kernel called by MatMul()
 __global__ void MatMulKernel(Matrix A, Matrix B, Matrix C)
{
    // Block row and column
    int blockRow = blockIdx.y;
    int blockCol = blockIdx.x;
    // Each thread block computes one sub-matrix Csub of C
    Matrix Csub = GetSubMatrix(C, blockRow, blockCol);
    // Each thread computes one element of Csub
    // by accumulating results into Cvalue
    float Cvalue = 0;
    // Thread row and column within Csub
    int row = threadIdx.y;
    int col = threadIdx.x;
    // Loop over all the sub-matrices of A and B that are
    // required to compute Csub
    // Multiply each pair of sub-matrices together
    // and accumulate the results
    for (int m = 0; m < (A.width / BLOCK_SIZE); ++m) {
        // Get sub-matrix Asub of A
        Matrix Asub = GetSubMatrix(A, blockRow, m);
        // Get sub-matrix Bsub of B
        Matrix Bsub = GetSubMatrix(B, m, blockCol);
        // Shared memory used to store Asub and Bsub respectively
        __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];
        __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];
        // Load Asub and Bsub from device memory to shared memory
        // Each thread loads one element of each sub-matrix
        As[row][col] = GetElement(Asub, row, col);
        Bs[row][col] = GetElement(Bsub, row, col);
        // Synchronize to make sure the sub-matrices are loaded
        // before starting the computation
        __syncthreads();
        if ((blockRow == 0) && (blockCol == 0) && (row == 0) && (col == 0)) printShMatrix(As[0], BLOCK_SIZE, BLOCK_SIZE);
        // Multiply Asub and Bsub together
        for (int e = 0; e < BLOCK_SIZE; ++e)
            Cvalue += As[row][e] * Bs[e][col];
        // Synchronize to make sure that the preceding
        // computation is done before loading two new
        // sub-matrices of A and B in the next iteration
        __syncthreads();
    }
    // Write Csub to device memory
    // Each thread writes one element
    SetElement(Csub, row, col, Cvalue);
}


 __global__ void MatMulKernel_shflA(Matrix A, Matrix B, Matrix C)
{
    // Block row and column
    int blockRow = blockIdx.y;
    int blockCol = blockIdx.x;
    // Each thread block computes one sub-matrix Csub of C
    Matrix Csub = GetSubMatrix(C, blockRow, blockCol);
    // Each thread computes one element of Csub
    // by accumulating results into Cvalue
    float Cvalue = 0;
    // Thread row and column within Csub
    int row = threadIdx.y;
    int col = threadIdx.x;
    // Loop over all the sub-matrices of A and B that are
    // required to compute Csub
    // Multiply each pair of sub-matrices together
    // and accumulate the results
    for (int m = 0; m < (A.width / BLOCK_SIZE); ++m) {
        // Get sub-matrix Asub of A
        Matrix Asub = GetSubMatrix(A, blockRow, m);
        // Get sub-matrix Bsub of B
        Matrix Bsub = GetSubMatrix(B, m, blockCol);
        // Shared memory used to store Asub and Bsub respectively
        __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];
        float my_A = GetElement(Asub, row, col);
        Bs[row][col] = GetElement(Bsub, row, col);
        // Synchronize to make sure the sub-matrices are loaded
        // before starting the computation
        __syncthreads();
        // Multiply Asub and Bsub together
        for (int e = 0; e < BLOCK_SIZE; ++e)
            Cvalue += __shfl_sync(0xFFFFFFFF, my_A, e) * Bs[e][col];
        // Synchronize to make sure that the preceding
        // computation is done before loading two new
        // sub-matrices of A and B in the next iteration
        __syncthreads();
    }
    // Write Csub to device memory
    // Each thread writes one element
    SetElement(Csub, row, col, Cvalue);
}

// Matrix multiplication - Host code
// Matrix dimensions are assumed to be multiples of BLOCK_SIZE
void MatMul(const Matrix A, const Matrix B, Matrix C)
{
    // Load A and B to device memory
    Matrix d_A;
    d_A.width = d_A.stride = A.width; d_A.height = A.height;
    size_t size = A.width * A.height * sizeof(float);
    cudaMalloc(&d_A.elements, size);
    cudaMemcpy(d_A.elements, A.elements, size,
               cudaMemcpyHostToDevice);
    Matrix d_B;
    d_B.width = d_B.stride = B.width; d_B.height = B.height;
    size = B.width * B.height * sizeof(float);
    cudaMalloc(&d_B.elements, size);
    cudaMemcpy(d_B.elements, B.elements, size,
               cudaMemcpyHostToDevice);
    Matrix h_C_shfl;
    h_C_shfl.width = h_C_shfl.stride = C.width; h_C_shfl.height = C.height;
    h_C_shfl.elements = new float[h_C_shfl.width*h_C_shfl.height];
    // Allocate C in device memory
    Matrix d_C;
    d_C.width = d_C.stride = C.width; d_C.height = C.height;
    size = C.width * C.height * sizeof(float);
    cudaMalloc(&d_C.elements, size);
    // Invoke kernel
    dim3 dimBlock(BLOCK_SIZE, BLOCK_SIZE);
    dim3 dimGrid(B.width / dimBlock.x, A.height / dimBlock.y);
    MatMulKernel<<<dimGrid, dimBlock>>>(d_A, d_B, d_C); // warm-up
    cudaDeviceSynchronize();
    unsigned long long dt = dtime_usec(0);
    MatMulKernel<<<dimGrid, dimBlock>>>(d_A, d_B, d_C);
    cudaDeviceSynchronize();
    dt = dtime_usec(dt);
    std::cout << "shared kernel time:  " << dt << "us" << std::endl;
    cudaMemcpy(C.elements, d_C.elements, size, cudaMemcpyDeviceToHost);
#if 0
    MatMulKernel_shflA<<<dimGrid, dimBlock>>>(d_A, d_B, d_C); // warm-up
    cudaDeviceSynchronize();
    dt = dtime_usec(0);
    MatMulKernel_shflA<<<dimGrid, dimBlock>>>(d_A, d_B, d_C);
    cudaDeviceSynchronize();
    dt = dtime_usec(dt);
    std::cout << "shuffle kernel time: " << dt << "us" << std::endl;
    // Read C from device memory
    cudaMemcpy(h_C_shfl.elements, d_C.elements, size, cudaMemcpyDeviceToHost);
    for (int i = 0; i < h_C_shfl.width*h_C_shfl.height; i++)
      if (h_C_shfl.elements[i] != C.elements[i]) {
        std::cout << "mismatch at: " << i << " should be: " << C.elements[i]
                  << " was: " << h_C_shfl.elements[i] << std::endl;
        return;}
#endif
    // Free device memory
    cudaFree(d_A.elements);
    cudaFree(d_B.elements);
    cudaFree(d_C.elements);
}

int main(int argc, char *argv[]){

  Matrix h_A, h_B, h_C;
  int dim = 4*BLOCK_SIZE;
  if (argc > 1) dim = atoi(argv[1])*BLOCK_SIZE;
  h_A.width=h_A.stride=h_A.height=h_B.width=h_B.stride=h_B.height=h_C.width=h_C.stride=h_C.height=dim;
  h_A.elements = new float[dim*dim];
  h_B.elements = new float[dim*dim];
  h_C.elements = new float[dim*dim];
  for (int i = 0; i < dim*dim; i++) {
    h_A.elements[i] = rand()%RNG;
    h_B.elements[i] = rand()%RNG;}
  MatMul(h_A, h_B, h_C);
}
# nvcc -o t15 t15.cu
# ./t15
1.00 0.00 2.00 1.00
0.00 2.00 1.00 1.00
2.00 1.00 1.00 0.00
1.00 0.00 0.00 2.00

0.00 2.00 2.00 2.00
0.00 0.00 2.00 2.00
0.00 2.00 2.00 0.00
2.00 2.00 0.00 1.00

0.00 1.00 2.00 0.00
2.00 2.00 1.00 2.00
1.00 0.00 2.00 0.00
1.00 2.00 0.00 1.00

2.00 2.00 1.00 2.00
2.00 1.00 0.00 2.00
2.00 0.00 2.00 1.00
1.00 0.00 0.00 0.00

1.00 0.00 2.00 1.00
0.00 2.00 1.00 1.00
2.00 1.00 1.00 0.00
1.00 0.00 0.00 2.00

0.00 2.00 2.00 2.00
0.00 0.00 2.00 2.00
0.00 2.00 2.00 0.00
2.00 2.00 0.00 1.00

0.00 1.00 2.00 0.00
2.00 2.00 1.00 2.00
1.00 0.00 2.00 0.00
1.00 2.00 0.00 1.00

2.00 2.00 1.00 2.00
2.00 1.00 0.00 2.00
2.00 0.00 2.00 1.00
1.00 0.00 0.00 0.00

shared kernel time:  1557us
#
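
Note that the As dump appears eight times: only block (0,0) prints, the m-loop runs over four sub-matrices for the default dim of 4*BLOCK_SIZE, and the kernel is launched twice (a warm-up launch plus the timed one).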

Yes, I did.