How to use sliced-K in GEMM?

Hi! I have written a sliced-K GEMM kernel, but it seems very slow… I tried to understand CUTLASS's sliced-K implementation but could not follow it, so I am posting my code here and explaining my approach, hoping someone can give me some suggestions. Thank you!!!

The basic idea comes from CUTLASS's blog post. My allocation is: 64 threads (2 warps) compute a 64 × 64 output tile; each thread reads 4 values from A and 4 values from B per step and accumulates 64 results. The block has 256 threads (8 warps) in total, so this is sliced-4: the 4 groups of warps all compute the same 64 × 64 output location, but K is split among them, so the partial results need to be reduced at the end.

For the reduction, each thread stores 1/4 of its 64 results into shared memory, so each slice group writes a 32 × 32 sub-tile at a time (16 KB across the 4 groups), and this is repeated 4 times. Each time, we add the values at corresponding locations to obtain a 4 KB reduced result and write it to global memory.
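
To make the reduction idea concrete, here is a minimal toy sketch of the same split-K pattern (illustration only, not the kernel below: the name splitk_toy, the (64, 4) block shape, and the one-row-per-block mapping are all made up for the example):

// Toy split-K illustration: blockDim = (64, 4); threadIdx.y selects which quarter
// of K this thread accumulates, and the four partial sums per output element are
// reduced through shared memory before one group writes C.
__global__ void splitk_toy(const float *A, const float *B, float *C, int n, int k) {
	int group = threadIdx.y;                           // K-slice group: 0..3
	int col = blockIdx.x * blockDim.x + threadIdx.x;   // output column
	int row = blockIdx.y;                              // output row
	int k_slice = k / 4;                               // assumes k % 4 == 0

	float partial = 0.f;
	for (int p = group * k_slice; p < (group + 1) * k_slice; ++p)
		partial += A[row * k + p] * B[p * n + col];

	// reduce the 4 partial sums per output element through shared memory
	__shared__ float partials[4][64];                  // 64 = blockDim.x in this toy
	partials[group][threadIdx.x] = partial;
	__syncthreads();

	if (group == 0)
		C[row * n + col] = partials[0][threadIdx.x] + partials[1][threadIdx.x]
			+ partials[2][threadIdx.x] + partials[3][threadIdx.x];
}

Launched as splitk_toy<<<dim3(n / 64, m), dim3(64, 4)>>>(A, B, C, n, k), it matches a plain GEMM as long as n is a multiple of 64 and k is a multiple of 4. The posted kernel below follows the same pattern, just with register/shared-memory tiling and double buffering on top.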

The Nsight Compute report shows: [profiler screenshot]

Thank you!!!

#include <iostream>
#include <cstdint>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <vector>

using namespace std;

#define FETCH_FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])


bool check(const float *A,
	const float *B,
	const float *C,
	int m, int n, int k) {
	for (int i = 0; i < m; ++i) {
		for (int j = 0; j < n; ++j) {
			float sum = 0.f;
			for (int p = 0; p < k; ++p) {
				sum += A[i * k + p] * B[j + p * n];
			}

			if (std::fabs(sum - C[i * n + j]) / std::fabs(sum) > 1e-5f) {
				printf("C[%d][%d] not match, %f vs %f\n", i, j, sum, C[i * n + j]);
				return false;
			}
		}
	}

	return true;
}


__device__ __forceinline__
uint32_t smem_u32addr(const void *smem_ptr) {
	uint32_t addr;
	asm("{.reg .u64 u64addr;\n"
		" cvta.to.shared.u64 u64addr, %1;\n"
		" cvt.u32.u64 %0, u64addr;}\n"
		: "=r"(addr)
		: "l"(smem_ptr)
	);

	return addr;
}

__device__ __forceinline__
void ldg32_nc(float &reg, const void *ptr, bool guard) {
	asm volatile (
		"{.reg .pred p;\n"
		" setp.ne.b32 p, %2, 0;\n"
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \
    __CUDA_ARCH__ >= 750
		" @p ld.global.nc.L2::128B.f32 %0, [%1];}\n"
#else
		" @p ld.global.nc.f32 %0, [%1];}\n"
#endif
		: "=f"(reg)
		: "l"(ptr), "r"((int)guard)
		);
}


__device__ __forceinline__ void ldg32_nc_0(float &reg, const void *ptr) {
	asm volatile("{mov.b32 %0, 0;\n"
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 &&                 \
    __CUDA_ARCH__ >= 750
		"ld.global.nc.L2::128B.f32 %0, [%1];}\n"
#else
		"ld.global.nc.f32 %0, [%1];}\n"
#endif
		: "=f"(reg)
		: "l"(ptr));
}


__device__ __forceinline__
void stg32(const float &reg, void *ptr, bool guard) {
	asm volatile (
		"{.reg .pred p;\n"
		" setp.ne.b32 p, %2, 0;\n"
		" @p st.global.f32 [%0], %1;}\n"
		: : "l"(ptr), "f"(reg), "r"((int)guard)
		);
}

__device__ __forceinline__
void lds128(float &reg0, float &reg1,
	float &reg2, float &reg3,
	const uint32_t &addr) {
	asm volatile (
		"ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];\n"
		: "=f"(reg0), "=f"(reg1), "=f"(reg2), "=f"(reg3)
		: "r"(addr)
		);
}

__device__ __forceinline__
void sts32(const float &reg, const uint32_t &addr) {
	asm volatile (
		"st.shared.f32 [%0], %1;\n"
		: : "r"(addr), "f"(reg)
		);
}

__device__ __forceinline__
void sts128(const float &reg0, const float &reg1,
	const float &reg2, const float &reg3,
	const uint32_t &addr) {
	asm volatile (
		"st.shared.v4.f32 [%0], {%1, %2, %3, %4};\n"
		: : "r"(addr), "f"(reg0), "f"(reg1), "f"(reg2), "f"(reg3)
		);
}


__device__ __forceinline__
void sts64(const float &reg0, const float &reg1,
	const uint32_t &addr) {
	asm volatile (
		"st.shared.v2.f32 [%0], {%1, %2};\n"
		: : "r"(addr), "f"(reg0), "f"(reg1)
		);
}


struct StgFrag {
	float data[4][4];

	__device__ __forceinline__
		StgFrag(const float(&C_frag)[8][8], int tile_x, int tile_y) {
#pragma unroll
		for (int i = 0; i < 4; ++i) {
#pragma unroll
			for (int j = 0; j < 4; ++j) {
				data[i][j] = C_frag[tile_y * 4 + i][tile_x * 4 + j];
			}
		}
	}
};

__device__ __noinline__
void C_tile_wb(StgFrag C_frag,
	float *C_stg_ptr,
	const float *C_lds_ptr,
	uint32_t C_sts_addr,
	uint32_t m,
	uint32_t n,
	uint32_t m_idx,
	uint32_t n_idx) {
	__syncthreads();

#pragma unroll
	for (int i = 0; i < 4; ++i) {
		sts128(C_frag.data[i][0],
			C_frag.data[i][1],
			C_frag.data[i][2],
			C_frag.data[i][3],
			C_sts_addr + i * 8 * sizeof(float4));
	}

	__syncthreads();

	uint32_t m_guard = m < m_idx ? 0 : m - m_idx;

#pragma unroll
	for (int i = 0; i < 16; ++i) {
		stg32(C_lds_ptr[i * 32],
			C_stg_ptr + i * n,
			i < m_guard && n_idx < n);
	}
}


__device__ __forceinline__ void stg128(const float &reg0, const float &reg1,
	const float &reg2, const float &reg3,
	const float *addr) {
	asm volatile("st.global.v4.f32 [%0], {%1, %2, %3, %4};\n"
		:
	: "l"(addr), "f"(reg0), "f"(reg1), "f"(reg2), "f"(reg3));
}

__global__ __launch_bounds__(256, 2) void sgemm_128x128x8(uint32_t m,
	uint32_t n,
	uint32_t k,
	float *A,
	float *B,
	float *C) {

	__shared__ __align__(8 * 1024) char smem[20 * 1024];  // 20 KB; the double-buffered A+B tiles use 16.5 KB of it
	float *A_smem = reinterpret_cast<float *>(smem);
	float *B_smem = reinterpret_cast<float *>(smem + 8704);  // B tiles start after the 8.5 KB A double buffer

	// A, B and C register fragment
	float A_frag[2][8];
	float B_frag[2][8];
	float C_frag[8][8];
#pragma unroll
	for (int i = 0; i < 8; ++i) {
#pragma unroll
		for (int j = 0; j < 8; ++j) {
			C_frag[i][j] = 0;
		}
	}

	const uint32_t lane_id = threadIdx.x % 32;
	const uint32_t warp_id = threadIdx.x / 32;

	// 4x8 threads each warp for FFMA
	const uint32_t mma_tid_x = (lane_id / 2) % 8;
	const uint32_t mma_tid_y = (lane_id / 16) * 2 + (lane_id % 2);

	// A_tile & B_tile ldg pointer
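	// each 64-thread group (threadIdx.x / 64) reads only its own K slice of length k/4;
	// per ldg step a group fetches a 64x4 tile of A and a 4x64 tile of B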
	int from_a = (blockIdx.y * 64 + (threadIdx.x % 64) / 4 * 4) * k + (threadIdx.x % 64) % 4 + (threadIdx.x / 64) * (k / 4);
	int from_b = ((threadIdx.x % 64) / 16 + (threadIdx.x / 64) * (k / 4)) * n + blockIdx.x * 64 + (threadIdx.x % 64) % 16 * 4;
	// A_tile & B_tile sts/lds pointer
	// using uint32_t pointer for faster double buffer switch

	uint32_t A_lds_addr = smem_u32addr(
		A_smem + (warp_id % 2) * 32 + mma_tid_y * 4 + (threadIdx.x / 64) * 68 * 4);
	uint32_t B_lds_addr = smem_u32addr(
		B_smem + mma_tid_x * 4 + (threadIdx.x / 64) * 64 * 4);

	float4 b_ldg_reg;
	float a_ldg_reg[4];


	uint32_t a_sts_addr = smem_u32addr(A_smem + ((threadIdx.x % 64) % 4) * 68 + ((threadIdx.x % 64) / 4) * 4 + (threadIdx.x / 64) * 68 * 4);
	uint32_t b_sts_addr = smem_u32addr(B_smem + ((threadIdx.x % 64) / 16) * 64 + ((threadIdx.x % 64) % 16) * 4 + (threadIdx.x / 64) * 64 * 4);
	// 1'st A&B tile loaded before the k_tile loop

	uint32_t k_tiles = (k / 4 + 3) / 4 - 1;
	uint32_t first_k_tile = k / 4 - k_tiles * 4 + (k / 4)*(threadIdx.x / 64);
	// load 1'st tile to shared memory
	{
		// load first
		// load gmem to smem for ashare
#pragma unroll
		for (int i = 0; i < 4; ++i) {
			if ((threadIdx.x % 64) % 4 + (threadIdx.x / 64)*(k / 4) < first_k_tile  &&  blockIdx.y * 64 + (threadIdx.x % 64) / 4 * 4 + i < m) {
				ldg32_nc_0(a_ldg_reg[i], (const char *)(A + from_a) + i * k * sizeof(float));
			}
			else {
				a_ldg_reg[i] = 0;
			}
		}
		sts128(a_ldg_reg[0], a_ldg_reg[1], a_ldg_reg[2], a_ldg_reg[3], a_sts_addr);

		// load gmem to smem for bshare


		if (from_b < (1 + (threadIdx.x % 64) / 16 + (threadIdx.x / 64)*(k / 4))*n && (threadIdx.x % 64) / 16 + (threadIdx.x / 64)*(k / 4) < first_k_tile) {
			b_ldg_reg = FETCH_FLOAT4(B[from_b]);
		}
		else {
			b_ldg_reg = float4{ 0, 0, 0, 0 };
		}

		FETCH_FLOAT4(B_smem[((threadIdx.x % 64) / 16) * 64 + ((threadIdx.x % 64) % 16) * 4 + (threadIdx.x / 64) * 64 * 4]) = b_ldg_reg;
		__syncthreads();
		// add offset and flip flag
		from_a += k / 4 - k_tiles * 4;
		from_b += (k / 4 - k_tiles * 4) * n;
		a_sts_addr += 68 * 4 * 4 * sizeof(float);
		b_sts_addr += 64 * 4 * 4 * sizeof(float);
	}



	// load 1st fragment
	lds128(A_frag[0][0], A_frag[0][1], A_frag[0][2], A_frag[0][3],
		A_lds_addr);
	lds128(A_frag[0][4], A_frag[0][5], A_frag[0][6], A_frag[0][7],
		A_lds_addr + 16 * sizeof(float));
	lds128(B_frag[0][0], B_frag[0][1], B_frag[0][2], B_frag[0][3],
		B_lds_addr);
	lds128(B_frag[0][4], B_frag[0][5], B_frag[0][6], B_frag[0][7],
		B_lds_addr + 32 * sizeof(float));



	int jump = 0;
	// k_tiles loop
	for (; k_tiles > 0; --k_tiles) {
		jump ^= 1;
#pragma unroll
		for (int k_frag = 0; k_frag < 4; ++k_frag) {
			// store next A&B tile to shared memory
			if (k_frag == 3) {
				sts128(a_ldg_reg[0], a_ldg_reg[1], a_ldg_reg[2], a_ldg_reg[3], a_sts_addr);
				sts128(b_ldg_reg.x, b_ldg_reg.y, b_ldg_reg.z, b_ldg_reg.w, b_sts_addr);
				__syncthreads();

				// switch double buffer
				if (jump == 1) {
					A_lds_addr += 68 * 4 * 4 * sizeof(float);
					B_lds_addr += 64 * 4 * 4 * sizeof(float);
					a_sts_addr -= 68 * 4 * 4 * sizeof(float);
					b_sts_addr -= 64 * 4 * 4 * sizeof(float);
				}
				else {
					A_lds_addr -= 68 * 4 * 4 * sizeof(float);
					B_lds_addr -= 64 * 4 * 4 * sizeof(float);
					a_sts_addr += 68 * 4 * 4 * sizeof(float);
					b_sts_addr += 64 * 4 * 4 * sizeof(float);
				}
				// ldg pointer for next tile
				from_a += 4;
				from_b += 4 * n;
			}

			// load next A&B fragment from shared memory to register
			lds128(A_frag[(k_frag + 1) % 2][0],
				A_frag[(k_frag + 1) % 2][1],
				A_frag[(k_frag + 1) % 2][2],
				A_frag[(k_frag + 1) % 2][3],
				A_lds_addr + (k_frag + 1) % 4 * 68 * sizeof(float));
			lds128(A_frag[(k_frag + 1) % 2][4],
				A_frag[(k_frag + 1) % 2][5],
				A_frag[(k_frag + 1) % 2][6],
				A_frag[(k_frag + 1) % 2][7],
				A_lds_addr + ((k_frag + 1) % 4 * 68 + 16) * sizeof(float));
			lds128(B_frag[(k_frag + 1) % 2][0],
				B_frag[(k_frag + 1) % 2][1],
				B_frag[(k_frag + 1) % 2][2],
				B_frag[(k_frag + 1) % 2][3],
				B_lds_addr + (k_frag + 1) % 4 * 64 * sizeof(float));
			lds128(B_frag[(k_frag + 1) % 2][4],
				B_frag[(k_frag + 1) % 2][5],
				B_frag[(k_frag + 1) % 2][6],
				B_frag[(k_frag + 1) % 2][7],
				B_lds_addr + ((k_frag + 1) % 4 * 64 + 32) * sizeof(float));
			// load next A&B tile
			if (k_frag == 0) {
				if (from_b < (1 + (threadIdx.x % 64) / 16 + (threadIdx.x / 64)*(k / 4))*n + (-k_tiles * 4 + k / 4)*n && (-k_tiles * 4 + k / 4) + (threadIdx.x % 64) / 16 + (threadIdx.x / 64)*(k / 4) < (threadIdx.x / 64 + 1)*(k / 4)) {

					b_ldg_reg = FETCH_FLOAT4(B[from_b]);
				}
				else {
					b_ldg_reg = float4{ 0, 0, 0, 0 };
				}

#pragma unroll
				for (int i = 0; i < 4; ++i) {
					if ((threadIdx.x % 64) % 4 + (threadIdx.x / 64)*(k / 4) + (-k_tiles * 4 + k / 4) < k && blockIdx.y * 64 + (threadIdx.x % 64) / 4 * 4 + i < m) {
						ldg32_nc_0(a_ldg_reg[i], (const char *)(A + from_a) + i * k * sizeof(float));
					}
					else {
						a_ldg_reg[i] = 0;
					}
				}
			}

			// FFMA loop
#pragma unroll
			for (int i = 0; i < 8; ++i) {
#pragma unroll
				for (int j = 0; j < 8; ++j) {
					C_frag[i][j] += A_frag[k_frag % 2][i] *
						B_frag[k_frag % 2][j];
				}
			}
		}
	}

	// FFMA for the last tile
#pragma unroll
	for (int k_frag = 0; k_frag < 4; ++k_frag) {
		if (k_frag < 3) {
			// load next A&B fragment from shared memory to register
			lds128(A_frag[(k_frag + 1) % 2][0],
				A_frag[(k_frag + 1) % 2][1],
				A_frag[(k_frag + 1) % 2][2],
				A_frag[(k_frag + 1) % 2][3],
				A_lds_addr + (k_frag + 1) % 4 * 68 * sizeof(float));
			lds128(A_frag[(k_frag + 1) % 2][4],
				A_frag[(k_frag + 1) % 2][5],
				A_frag[(k_frag + 1) % 2][6],
				A_frag[(k_frag + 1) % 2][7],
				A_lds_addr + ((k_frag + 1) % 4 * 68 + 16) * sizeof(float));
			lds128(B_frag[(k_frag + 1) % 2][0],
				B_frag[(k_frag + 1) % 2][1],
				B_frag[(k_frag + 1) % 2][2],
				B_frag[(k_frag + 1) % 2][3],
				B_lds_addr + (k_frag + 1) % 4 * 64 * sizeof(float));
			lds128(B_frag[(k_frag + 1) % 2][4],
				B_frag[(k_frag + 1) % 2][5],
				B_frag[(k_frag + 1) % 2][6],
				B_frag[(k_frag + 1) % 2][7],
				B_lds_addr + ((k_frag + 1) % 4 * 64 + 32) * sizeof(float));
		}

		// FFMA loop
#pragma unroll
		for (int i = 0; i < 8; ++i) {
#pragma unroll
			for (int j = 0; j < 8; ++j) {
				C_frag[i][j] += A_frag[k_frag % 2][i] *
					B_frag[k_frag % 2][j];
			}
		}
	}


	// C_tile write back, reuse A&B tile shared memory buffer
	uint32_t C_sts_addr = smem_u32addr((float4 *)(smem + warp_id * 2048) +
		mma_tid_y * 4 * 8 + mma_tid_x);
	uint32_t C_lds_ptr = smem_u32addr(A_smem + (mma_tid_y * 4 * 8 + mma_tid_x) * 4 + (warp_id % 2) * 16 * 32);
	uint32_t C_lds_addr = smem_u32addr(A_smem + threadIdx.x / 8 * 32 + (threadIdx.x % 8) * 4);



	uint32_t m_idx = blockIdx.y * 64;

	if (m_idx >= m) {
		return;
	}
	else if (m_idx + 32 <= m) {

#pragma unroll
		for (int i = 0; i < 2; ++i) {


			for (int j = 0; j < 2; ++j) {
				__syncthreads();

#pragma unroll
				for (int p = 0; p < 4; ++p) {
					sts128(C_frag[i * 4 + p][j * 4],
						C_frag[i * 4 + p][j * 4 + 1],
						C_frag[i * 4 + p][j * 4 + 2],
						C_frag[i * 4 + p][j * 4 + 3],
						C_sts_addr + p * 8 * sizeof(float4));
				}

				__syncthreads();
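				// read back the same 32x32 sub-tile written by each of the 4 K-slice groups
				// (4 KB apart in shared memory) and sum the partial results before storing to C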


				lds128(B_frag[0][0], B_frag[0][1], B_frag[0][2], B_frag[0][3], C_lds_addr);
				lds128(B_frag[0][4], B_frag[0][5], B_frag[0][6], B_frag[0][7],
					C_lds_addr + (32 * 32) * sizeof(float));
				lds128(B_frag[1][0], B_frag[1][1], B_frag[1][2], B_frag[1][3],
					C_lds_addr + (32 * 32 * 2) * sizeof(float));
				lds128(B_frag[1][4], B_frag[1][5], B_frag[1][6], B_frag[1][7],
					C_lds_addr + (32 * 32 * 3) * sizeof(float));


				B_frag[0][0] += (B_frag[0][4] + B_frag[1][0] + B_frag[1][4]);
				B_frag[0][1] += (B_frag[0][5] + B_frag[1][1] + B_frag[1][5]);
				B_frag[0][2] += (B_frag[0][6] + B_frag[1][2] + B_frag[1][6]);
				B_frag[0][3] += (B_frag[0][7] + B_frag[1][3] + B_frag[1][7]);

				if (blockIdx.y * 64 + (threadIdx.x % 128) / 8 + i * 16 + (threadIdx.x / 128) * 32 < m &&
					blockIdx.x * 64 + (threadIdx.x % 8) * 4 + j * 32 < n) {
					stg128(B_frag[0][0], B_frag[0][1], B_frag[0][2], B_frag[0][3],
						C + (blockIdx.y * 64 + (threadIdx.x % 128) / 8 + i * 16 + (threadIdx.x / 128) * 32) * n
						+ blockIdx.x * 64 + (threadIdx.x % 8) * 4 + j * 32);
				}
			}
		}
	}
	//	else {
	//#pragma unroll
	//		for (int i = 0; i < 2; ++i) {
	//#pragma unroll
	//			for (int j = 0; j < 2; ++j) {
	//				StgFrag stg_frag(C_frag, j, i);

	//				C_tile_wb(stg_frag,
	//					C_stg_ptr + i * 16 * n + j * 32,
	//					C_lds_ptr,
	//					C_sts_addr,
	//					m,
	//					n,
	//					m_idx + i * 16,
	//					n_idx + j * 32);
	//			}
	//		}
	//	}
}






float* random_matrix(int row, int col) {
	float* mat = new float[row * col];


	for (int i = 0; i < row; ++i) {
		for (int j = 0; j < col; ++j) {
			if (i * col + j + 1 < 10) {
				mat[i * col + j] = i * col + j + 1;
			}
			else {
				mat[i * col + j] = 0.5;
			}
		}
	}

	return mat;
}

// float* random_matrix(int row, int col) {
// 	float* mat = new float[row * col];
//
//
// 	for (int i = 0; i < row; ++i) {
// 		for (int j = 0; j < col; ++j) {
// 			mat[i * col + j] = 0;
// 		}
// 	}
//
// 	return mat;
// }


void print_mat(float* mat, int row, int col) {
	/* Display the matrix for visualization */
	for (int i = 0; i < row; ++i) {
		for (int j = 0; j < col; ++j) {
			cout << mat[i * col + j] << " ";
		}
		cout << endl;
	}
	cout << "\n" << endl;
}



int main()
{
	const int m = 3072, k = 3072, n = 64;
	float* a = random_matrix(m, k);
	float* b = random_matrix(k, n);
	float* c = new float[m*n];

	float* dev_a, *dev_b, *dev_c;

	cudaMalloc((void**)&dev_a, m * k * sizeof(float));
	cudaMalloc((void**)&dev_b, k * n * sizeof(float));
	cudaMalloc((void**)&dev_c, m * n * sizeof(float));

	cudaMemcpy(dev_a, a, m * k * sizeof(float), cudaMemcpyHostToDevice);
	cudaMemcpy(dev_b, b, k * n * sizeof(float), cudaMemcpyHostToDevice);


	constexpr int BLOCK = 64;
	dim3 grid((n + BLOCK - 1) / BLOCK, (m + BLOCK - 1) / BLOCK);
	int repeat = 1;


	cudaEvent_t start, stop;
	cudaEventCreate(&start);
	cudaEventCreate(&stop);
	cudaEventRecord(start);
	cudaEventQuery(start);

	for (int i = 0; i < repeat; i++) {
		sgemm_128x128x8<<<grid, 256>>>(m, n, k, dev_a, dev_b, dev_c);
	}


	cudaEventRecord(stop);
	cudaEventSynchronize(stop);
	float elapsed_time;
	cudaEventElapsedTime(&elapsed_time, start, stop);
	printf("Time = %g ms .\n", elapsed_time / repeat);
	cudaEventDestroy(start);
	cudaEventDestroy(stop);


	cudaMemcpy(c, dev_c, m * n * sizeof(float), cudaMemcpyDeviceToHost);
	bool chk = check(a, b, c, m, n, k);
	printf("Matrix_C check: %s\n", chk ? "OK" : "Failed");
	//cout << 'a' << endl;
	//print_mat(a, m, k);
	//cout << 'b' << endl;
	//print_mat(b, k, n);
	//cout << 'c' << endl;
	//print_mat(c, m, n);
}

From your profiler output, you are using ~25% of the available compute throughput/capability, and ~25% of the available memory throughput/capability.

In these situations, a common problem is lack of latency hiding. The GPU spends much of its time with stalled warps, waiting for those stalls to clear.

In fact, your profiler output indicates this via a “rules”/warning block:

Latency Issue The kernel exhibits low …

Your kernel launch here is using 124 registers per thread, which means your maximum achievable occupancy is 50% of peak. The GPU likes to have as many threads resident as possible to hide latency, so this might be one area to look at.
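
If it helps to see those numbers outside the profiler, a small host-side check along these lines (a sketch; report_occupancy is just a hypothetical helper added to the same .cu file as the kernel) prints the register count and the theoretical occupancy for the 256-thread launch:

#include <cstdio>
#include <cuda_runtime.h>

void report_occupancy() {
	// per-thread register count and static shared memory of the posted kernel
	cudaFuncAttributes attr;
	cudaFuncGetAttributes(&attr, sgemm_128x128x8);
	printf("registers per thread: %d, static smem: %zu bytes\n",
		attr.numRegs, attr.sharedSizeBytes);

	// how many 256-thread blocks can be resident per SM with this register/smem usage?
	int blocks_per_sm = 0;
	cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, sgemm_128x128x8, 256, 0);

	cudaDeviceProp prop;
	cudaGetDeviceProperties(&prop, 0);
	printf("theoretical occupancy: %.0f%%\n",
		100.0f * blocks_per_sm * 256 / prop.maxThreadsPerMultiProcessor);
}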

Furthermore the warning itself gives suggestions about where to look to help localize the biggest contributors to exposed latency.

To get additional familiarity with Nsight Compute, this blog series may be of interest.

FWIW, an SGEMM kernel should be mostly compute-bound, which means that the compute bar in the Nsight Compute output should be in the 60-90% range. The blog above shows an example of this (in part 3), profiling a cuBLAS SGEMM kernel. So getting out of this latency-bound hole should be possible.
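
If you want a concrete baseline to compare against, timing the same problem through cuBLAS is straightforward. A sketch, assuming the same row-major dev_a/dev_b/dev_c buffers as in your main() and linking with -lcublas (cuBLAS is column-major, so row-major C = A·B is requested as C^T = B^T·A^T):

#include <cublas_v2.h>

// Row-major C(m x n) = A(m x k) * B(k x n) via cuBLAS (column-major):
// ask for C^T = B^T * A^T, which lands in dev_c with row-major layout.
void cublas_baseline(const float *dev_a, const float *dev_b, float *dev_c,
	int m, int n, int k) {
	cublasHandle_t handle;
	cublasCreate(&handle);

	const float alpha = 1.0f, beta = 0.0f;
	cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
		n, m, k,
		&alpha,
		dev_b, n,   // leading dimension of B^T (n x k) is n
		dev_a, k,   // leading dimension of A^T (k x m) is k
		&beta,
		dev_c, n);  // leading dimension of C^T (n x m) is n

	cublasDestroy(handle);
}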
