Hi all,
I’m trying to learn how TMA works, so I wrote a simple demo following this blog (link). In my tma_load_kernel function, I hit a runtime error on this call:
copy(tma_load.with(tma_load_mbar),
     tma_load_per_cta.partition_S(gmem_tensor_coord_cta),
     tma_load_per_cta.partition_D(smem_tensor));
Here is the full file:
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <stdlib.h>
#include <cute/tensor.hpp>
#define CHECK_CUDA2(func)                                            \
  {                                                                   \
    cudaError_t status = (func);                                      \
    if (status != cudaSuccess) {                                      \
      printf("CUDA API failed at line %d with error: %s (%d)\n",      \
             __LINE__, cudaGetErrorString(status), status);           \
    }                                                                 \
  }
using namespace cute;
template <typename T, int CTA_M, int CTA_N, class TmaLoad, class GmemTensor>
__global__ void tma_load_kernel(__grid_constant__ const TmaLoad tma_load,
                                GmemTensor gmem_tensor) {
  using namespace cute;
  constexpr int tma_transaction_bytes = CTA_M * CTA_N * sizeof(T);

  __shared__ T smem_data[CTA_M * CTA_N];
  __shared__ uint64_t tma_load_mbar;
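  // The mbarrier must live in shared memory: thread 0 initializes and arms it,
  // and every thread in the CTA later waits on it.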
  auto smem_layout = make_layout(make_shape(CTA_M, CTA_N), LayoutRight{});
  auto smem_tensor = make_tensor(make_smem_ptr(smem_data), smem_layout);
  if (threadIdx.x == 0) {
    auto gmem_tensor_coord = tma_load.get_tma_tensor(shape(gmem_tensor));
    auto gmem_tensor_coord_cta = local_tile(
        gmem_tensor_coord,
        Tile<Int<CTA_M>, Int<CTA_N>>{},
        make_coord(blockIdx.x, blockIdx.y));
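    // local_tile slices the TMA coordinate tensor into this CTA's
    // (CTA_M, CTA_N) tile, selected by the block coordinates.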
    if (blockIdx.x == 0 && blockIdx.y == 0) {
      cute::print_tensor(gmem_tensor_coord_cta);
    }
    initialize_barrier(tma_load_mbar, /* arrival count */ 1);
    set_barrier_transaction_bytes(tma_load_mbar, tma_transaction_bytes);
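    // expect-tx: the barrier phase completes only after the TMA unit has
    // deposited tma_transaction_bytes bytes into shared memory.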
    auto tma_load_per_cta = tma_load.get_slice(0);
    copy(tma_load.with(tma_load_mbar),
         tma_load_per_cta.partition_S(gmem_tensor_coord_cta),
         tma_load_per_cta.partition_D(smem_tensor));
  }
  __syncthreads();
  wait_barrier(tma_load_mbar, /* phase */ 0);
  // after this line, the TMA load is finished
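  // (the phase bit starts at 0 after initialize_barrier and flips once the
  // expected transaction bytes arrive, so the first load waits on phase 0)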
  if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
    printf("DEBUG SMEM\n");
    for (int i = 0; i < CTA_M; i++) {
      for (int j = 0; j < CTA_N; j++) {
        printf("%.0f ", smem_tensor(i, j));
      }
      printf("\n");
    }
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    for (int i = 0; i < CTA_M; i++) {
      for (int j = 0; j < CTA_N; j++) {
        smem_tensor(i, j) += smem_tensor(j, i);
      }
    }
    *(T*)(gmem_tensor.data().get()) = smem_tensor(0, 0);
  }
}
template <typename T, int CTA_M, int CTA_N>
void host_fn(T* data, int M, int N) {
  using namespace cute;

  T* ddata;
  CHECK_CUDA2(cudaMalloc(&ddata, sizeof(T) * M * N));
  CHECK_CUDA2(cudaMemcpy(ddata, data, sizeof(T) * M * N, cudaMemcpyHostToDevice));

  // create the GMEM tensor
  auto gmem_layout = make_layout(make_shape(M, N), LayoutRight{});
  auto gmem_tensor = make_tensor(make_gmem_ptr(ddata), gmem_layout);
  using GmemTensor = decltype(gmem_tensor);

  // create the SMEM layout; it must be fully static (shape known at compile time)
  auto smem_layout = make_layout(make_shape(Int<CTA_M>{}, Int<CTA_N>{}), LayoutRight{});
  // create the TMA object
  using tma_class = decltype(make_tma_copy(SM90_TMA_LOAD{}, gmem_tensor, smem_layout));
  auto tma_load = tma_class{};
  // auto tma_load = make_tma_copy(SM90_TMA_LOAD{}, gmem_tensor, smem_layout);
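  // (note: tma_load above is default-constructed from the decltype; the
  // commented-out line is the direct make_tma_copy construction from the blog)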
  // invoke the kernel
  tma_load_kernel<T, CTA_M, CTA_N, tma_class, GmemTensor>
      <<<dim3{M / CTA_M, N / CTA_N, 1}, 32>>>(tma_load, gmem_tensor);
  CHECK_CUDA2(cudaDeviceSynchronize());
  CHECK_CUDA2(cudaGetLastError());
}
int main() {
  constexpr int data_M = 64;
  constexpr int data_N = 64;
  float* data = (float*)malloc(data_M * data_N * sizeof(float));
  for (int i = 0; i < data_M * data_N; i++) {
    data[i] = i;
  }
  host_fn<float, 8, 8>(data, data_M, data_N);
  free(data);
  return 0;
}
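In case it matters, I build it roughly like this (include path is a placeholder; CUTLASS needs C++17, and the SM90 TMA path needs an sm_90 target such as sm_90a):

nvcc -std=c++17 -arch=sm_90a -I/path/to/cutlass/include demo.cu -o demo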
The runtime errors detected by compute-sanitizer are:
========= Illegal instruction
========= at cute::SM90_TMA_LOAD_2D::copy(const void *, unsigned long *, void *, const int &, const int &)+0x2fa0 in /data/home/tester/zkg/learn_kernel/cutlass/include/cute/arch/copy_sm90_tma.hpp:108
========= by thread (1,0,0) in block (0,1,0)
========= Device Frame:cute::SM90_TMA_LOAD::copy(const void *, unsigned long *, void *, const int &, const int &)+0x2db0 in /data/home/tester/zkg/learn_kernel/cutlass/include/cute/arch/copy_sm90_tma.hpp:286
========= Device Frame:void cute::detail::CallCOPY<cute::SM90_TMA_LOAD_OP>::operator ()<const CUtensorMap_st *const &, unsigned long *const &, void *&, unsigned int &, unsigned int &>(T1 &&...) const+0x2db0 in /data/home/tester/zkg/learn_kernel/cutlass/include/cute/arch/util.hpp:159
========= Device Frame:void cute::detail::explode_tuple<cute::detail::CallCOPY<cute::SM90_TMA_LOAD_OP>, const cute::tuple<const CUtensorMap_st *, unsigned long *> &, (int)0, (int)1, cute::tuple<void *>, (int)0, cute::ArithmeticTuple<unsigned int, unsigned int> &, (int)0, (int)1>(T1, T2 &&, std::integer_sequence<int, T3>, T4 &&, std::integer_sequence<int, T5>, T6 &&, std::integer_sequence<int, T7>)+0x2db0 in /data/home/tester/zkg/learn_kernel/cutlass/include/cute/arch/util.hpp:293
...
========= Device Frame:void tma_load_kernel<float, (int)8, (int)8, cute::TiledCopy<cute::Copy_Atom<cute::Copy_Traits<cute::SM90_TMA_LOAD, cute::C<(int)2048>, cute::AuxTmaParams<cute::tuple<cute::ScaledBasis<cute::C<(int)1>, (int)1>, cute::ScaledBasis<cute::C<(int)1>, (int)0>>, const cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)8>>, cute::tuple<cute::ScaledBasis<cute::C<(int)1>, (int)1>, cute::ScaledBasis<cute::C<(int)1>, (int)0>>> &, const cute::Swizzle<(int)0, (int)4, (int)3> &>>, float>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::tuple<cute::tuple<cute::C<(int)8>, cute::C<(int)8>>>>, cute::tuple<cute::C<(int)0>, cute::tuple<cute::tuple<cute::C<(int)8>, cute::C<(int)1>>>>>, cute::tuple<cute::C<(int)8>, cute::C<(int)8>>>, cute::Tensor<cute::ViewEngine<cute::gmem_ptr<float *>>, cute::Layout<cute::tuple<int, int>, cute::tuple<int, cute::C<(int)1>>>>>(T4, T5)+0x2db0 in /xxx/learn_kernel/my_learn_tma/tma_cutlass_tutorial/demo.cu:60
...
However, in my code only threadIdx.x == 0 executes the copy() call, so I’m confused why thread (1,0,0) is the one reported for the illegal instruction.
Has anyone run into something similar? Any idea what might be going wrong?
Thanks in advance!