Hi you guys. I was wondering whether my WMMA code here follows good practice for loading matrices.
// Multiplies four tiles of permutation_matrix by two tiles of permutation_vectors
// using WMMA and writes the two resulting tiles back over permutation_vectors.
// Presumably this is a 2x2-blocked matrix times 2x1-blocked operand product
// (nw/ne/sw/se and up/down) — TODO confirm against the definitions of
// NEXT_BLOCK and LOWER_ROW, which are not visible here.

// Fragment types shared by the tiles below. wmma::fragment takes its tile
// shape as (m, n, k), in that order.
// NOTE(review): confirm OUTER_WIDTH / INNER_DIM / OUTER_HEIGHT are intended as m / n / k.
using MatTileFrag = wmma::fragment<wmma::matrix_a, OUTER_WIDTH, INNER_DIM, OUTER_HEIGHT, half, wmma::row_major>;
using VecTileFrag = wmma::fragment<wmma::matrix_b, OUTER_WIDTH, INNER_DIM, OUTER_HEIGHT, half, wmma::col_major>;
using ProdTileFrag = wmma::fragment<wmma::accumulator, OUTER_WIDTH, INNER_DIM, OUTER_HEIGHT, half>;

// matrix_b operands: two col-major tiles read from permutation_vectors,
// the second one NEXT_BLOCK elements after the first.
VecTileFrag vec_up_frag;
VecTileFrag vec_down_frag;
wmma::load_matrix_sync(vec_up_frag, permutation_vectors, PERMUTATION_LENGTH);
wmma::load_matrix_sync(vec_down_frag, permutation_vectors + NEXT_BLOCK, PERMUTATION_LENGTH);

// matrix_a operands: four row-major tiles read from permutation_matrix at
// offsets 0, NEXT_BLOCK, LOWER_ROW and LOWER_ROW + NEXT_BLOCK.
MatTileFrag mat_nw_frag;
MatTileFrag mat_ne_frag;
MatTileFrag mat_sw_frag;
MatTileFrag mat_se_frag;
wmma::load_matrix_sync(mat_nw_frag, permutation_matrix, PERMUTATION_LENGTH);
wmma::load_matrix_sync(mat_ne_frag, permutation_matrix + NEXT_BLOCK, PERMUTATION_LENGTH);
wmma::load_matrix_sync(mat_sw_frag, permutation_matrix + LOWER_ROW, PERMUTATION_LENGTH);
wmma::load_matrix_sync(mat_se_frag, permutation_matrix + LOWER_ROW + NEXT_BLOCK, PERMUTATION_LENGTH);

// A single accumulator fragment is reused for both output tiles.
ProdTileFrag prod_frag;

// First output tile: nw * up + ne * down.
// All operand loads were issued above, so overwriting permutation_vectors
// here does not affect the values already held in vec_up_frag / vec_down_frag.
// NOTE(review): whether other warps read permutation_vectors concurrently
// cannot be determined from this snippet — verify before relying on the in-place store.
wmma::fill_fragment(prod_frag, 0.0f);
wmma::mma_sync(prod_frag, mat_nw_frag, vec_up_frag, prod_frag);
wmma::mma_sync(prod_frag, mat_ne_frag, vec_down_frag, prod_frag);
wmma::store_matrix_sync(permutation_vectors, prod_frag, PERMUTATION_LENGTH, wmma::mem_col_major);

// Second output tile: sw * up + se * down, stored NEXT_BLOCK elements
// after the first. The accumulator is zeroed again before reuse.
wmma::fill_fragment(prod_frag, 0.0f);
wmma::mma_sync(prod_frag, mat_sw_frag, vec_up_frag, prod_frag);
wmma::mma_sync(prod_frag, mat_se_frag, vec_down_frag, prod_frag);
wmma::store_matrix_sync(permutation_vectors + NEXT_BLOCK, prod_frag, PERMUTATION_LENGTH, wmma::mem_col_major);