Hi, the code is below, before calling mma instruction, it shows 0 errors in compute-sanitizer.
__global__
void tf32_computeX32(
const BIT_TYPE* __restrict__ d_rowWndOffsetBit,
const BIT_TYPE* __restrict__ d_tcLocalBit,
const vint* __restrict__ d_sparseA2B,
const MAT_VAL_TYPE* __restrict__ d_valueA,
const MAT_PTR_TYPE* __restrict__ d_block2Idx,
const vint* __restrict__ d_data2Idx,
const MAT_VAL_TYPE* __restrict__ d_MatB,
MAT_VAL_TYPE* d_MatC,
const vint numNodes,
const vint feature_dim
) {
using ARegisters = MAT_VAL_TYPE[4];
using BRegisters = MAT_VAL_TYPE[2];
using CRegisters = MAT_VAL_TYPE[4];
ARegisters fragA;
BRegisters fragB;
CRegisters fragC = {0.0};
const vint bid = blockIdx.x;
const vint laneid = 31 & threadIdx.x;
const vint warpSize = blockDim.x;
const vint threadPerBlk = blockDim.x * blockDim.y;
const vint dimTileNum = feature_dim / COL_WINDOW;
const vint tid = threadIdx.y * warpSize + laneid; // local thread ID
// const vint global_warpID = (global_tid >> 5) * WARP_NUM;
const vint local_warpID = threadIdx.y;
int groupID = laneid >> 2;
int tID_in_group = 3 & laneid;
// load A
vint row02 = groupID;
vint row13 = groupID + 8;
vint col01 = tID_in_group;
vint col23 = tID_in_group + 4;
vint row_b0 = tID_in_group;
vint row_b1 = tID_in_group + 4;
vint col_b = groupID;
vint row_c01 = groupID;
vint row_c23 = groupID + 8;
vint dense_rowIdx_off = laneid & 3;
vint denseDimIdx = (laneid >> 2) + local_warpID * COL_WINDOW;
vint denseDimIdx1 = denseDimIdx + 8;
constexpr const int inst_m = 16;
constexpr const int inst_k = 8;
constexpr const int inst_n = 8;
constexpr const int mat_len = inst_m * inst_k;
constexpr const int idx_len = inst_n;
__shared__ MAT_VAL_TYPE d_sharedSparseA[2 * mat_len];
__shared__ vint d_sharedSparseA2B[2 * idx_len];
MAT_VAL_TYPE d_denseB[inst_m * inst_n];
int saPtr = __cvta_generic_to_shared(d_sharedSparseA);
int siPtr = __cvta_generic_to_shared(d_sharedSparseA2B);
MAT_PTR_TYPE start_blk_idx = d_block2Idx[bid];
MAT_PTR_TYPE end_blk_idx = d_block2Idx[bid+1];
vint start_data_idx = d_data2Idx[start_blk_idx];
vint end_data_idx = d_data2Idx[start_blk_idx+1];
const vint denseBound = numNodes * feature_dim;
#pragma unroll
for(vint i = tid; i < mat_len; i += threadPerBlk) {
d_sharedSparseA[i] = (d_tcLocalBit[i + inst_k * inst_m * start_blk_idx] == false) ?
0.0 : d_valueA[start_data_idx +
get_data_idx_offset(
inst_k * inst_m * start_blk_idx, i + inst_k * inst_m * start_blk_idx, d_tcLocalBit)];
}
__syncthreads();
if(tid < inst_k) {
d_sharedSparseA2B[tid] = d_sparseA2B[start_blk_idx * inst_m + tid];
}
__syncthreads();
for(vint tc_block = start_blk_idx + 1; tc_block < end_blk_idx; ++tc_block) {
vint start_idx = tc_block * ROW_WINDOW * COL_WINDOW;
vint start_col_idx = tc_block * COL_WINDOW;
vint shared_mem_sel = ((tc_block - start_blk_idx) & 1) ^ 1;
vint shared_mem_sel_next = ((tc_block - start_blk_idx - 1) & 1) ^ 1;
if(local_warpID < dimTileNum) {
vint dense_rowIdx0 = d_sharedSparseA2B[(shared_mem_sel << 3) + dense_rowIdx_off];
vint dense_rowIdx1 = d_sharedSparseA2B[(shared_mem_sel << 3) + dense_rowIdx_off + 4];
if(dense_rowIdx0 > numNodes) {
fragB[0] = 1.0;
}
else {
vint sourceIdx0 = dense_rowIdx0 * feature_dim + denseDimIdx;
fragB[0] = d_MatB[sourceIdx0];
}
if(dense_rowIdx1 > numNodes) {
fragB[1] = 0.0;
} else {
vint sourceIdx1 = dense_rowIdx1 * feature_dim + denseDimIdx;
// fragB[1] = load_fp32_from_global(d_MatB + sourceIdx1);
fragB[1] = d_MatB[sourceIdx1];
}
vint start_data_idx_inner = d_data2Idx[tc_block];
vint end_data_idx_inner = d_data2Idx[tc_block+1];
#pragma unroll
for(vint i = tid; i < mat_len; i += threadPerBlk) {
MAT_VAL_TYPE val = (d_tcLocalBit[i + inst_k * inst_m * tc_block] == false) ?
0.0 : d_valueA[start_data_idx_inner + get_data_idx_offset(
inst_k * inst_m * tc_block, i + inst_k * inst_m * tc_block, d_tcLocalBit)];
d_sharedSparseA[i + (shared_mem_sel_next << 7)] = val;
}
if(tid < inst_m) {
vint val = d_sparseA2B[tc_block * inst_m + tid];
d_sharedSparseA2B[(shared_mem_sel_next << 3) + tid] = val;
}
fragA[0] = d_sharedSparseA[(shared_mem_sel << 7) + row02 * inst_k + col01];
fragA[1] = d_sharedSparseA[(shared_mem_sel << 7) + row13 * inst_k + col01];
fragA[2] = d_sharedSparseA[(shared_mem_sel << 7) + row02 * inst_k + col23];
fragA[3] = d_sharedSparseA[(shared_mem_sel << 7) + row13 * inst_k + col23];
vint const* A = reinterpret_cast<vint const*>(fragA);
vint const* B = reinterpret_cast<vint const*>(fragB);
float *C = reinterpret_cast<float*>(fragC);
asm(
"cvt.rna.tf32.f32 %4, %4;\n"
"cvt.rna.tf32.f32 %5, %5;\n"
"cvt.rna.tf32.f32 %6, %6;\n"
"cvt.rna.tf32.f32 %7, %7;\n"
"cvt.rna.tf32.f32 %8, %8;\n"
"cvt.rna.tf32.f32 %9, %9;\n"
"mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%0, %1, %2, %3};"
:"+f"(C[0]), "+f"(C[1]), "+f"(C[2]), "+f"(C[3]) // output
:"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
"r"(B[0]), "r"(B[1])
);
}
}
}