Out-of-bounds global memory write, but inside the allocation boundary

I’m writing a kernel for PyTorch, and it produces the following error under compute-sanitizer --tool memcheck:

========= Invalid __global__ write of size 16 bytes
=========     at 0xa750 in /home/sean/myrepo/packages/state_kernel/csrc/cutlass/include/cute/arch/copy.hpp:56:void cute::UniversalCopy<cutlass::uint128_t, cutlass::uint128_t>::copy<cutlass::uint128_t, cutlass::uint128_t>(const T1 &, T2 &)
=========     by thread (67,0,0) in block (11,10,106)
=========     Address 0x7f3cfc2eac30 is out of bounds
=========     and is inside the nearest allocation at 0x7f3cbe000000 of size 4,989,124,608 bytes
=========     Device Frame:/home/sean/myrepo/packages/state_kernel/csrc/cutlass/include/cute/algorithm/copy.hpp:136:void cute::copy_if<cute::TrivialPredTensor, cute::ViewEngine<cute::smem_ptr<const cutlass::uint128_t *>>, cute::Layout<cute::tuple<cute::tuple<cute::C<(int)1>, cute::C<(int)1>>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::tuple<cute::C<(int)0>, cute::C<(int)1>>, cute::C<(int)0>, cute::C<(int)0>>>, cute::ViewEngine<cute::gmem_ptr<cutlass::uint128_t *>>, cute::Layout<cute::tuple<cute::tuple<cute::C<(int)1>, cute::C<(int)1>>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::tuple<cute::C<(int)0>, cute::C<(int)1>>, cute::C<(int)0>, cute::C<(int)0>>>>(const T1 &, const cute::Tensor<T2, T3> &, cute::Tensor<T4, T5> &) [0xa6a0]
=========     Device Frame:/home/sean/myrepo/packages/state_kernel/csrc/cutlass/include/cute/algorithm/copy.hpp:221:void cute::copy_vec<cutlass::uint128_t, cute::ViewEngine<cute::smem_ptr<cutlass::bfloat16_t *>>, cute::Layout<cute::tuple<cute::tuple<cute::C<(int)1>, cute::C<(int)8>>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::tuple<cute::C<(int)0>, cute::C<(int)1>>, cute::C<(int)0>, cute::C<(int)0>>>, cute::ViewEngine<cute::gmem_ptr<cutlass::bfloat16_t *>>, cute::Layout<cute::tuple<cute::tuple<cute::C<(int)1>, cute::C<(int)8>>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::tuple<cute::C<(int)0>, cute::C<(int)1>>, cute::C<(int)0>, cute::C<(int)0>>>>(const cute::Tensor<T2, T3> &, cute::Tensor<T4, T5> &) [0xa6a0]
=========     Device Frame:/home/sean/myrepo/packages/state_kernel/csrc/cutlass/include/cute/algorithm/copy.hpp:283:void cute::copy<(int)128, , cute::ViewEngine<cute::smem_ptr<cutlass::bfloat16_t *>>, cute::Layout<cute::tuple<cute::tuple<cute::C<(int)1>, cute::C<(int)8>>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::tuple<cute::C<(int)0>, cute::C<(int)1>>, cute::C<(int)0>, cute::C<(int)0>>>, cute::ViewEngine<cute::gmem_ptr<cutlass::bfloat16_t *>>, cute::Layout<cute::tuple<cute::tuple<cute::C<(int)1>, cute::C<(int)8>>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::tuple<cute::C<(int)0>, cute::C<(int)1>>, cute::C<(int)0>, cute::C<(int)0>>>>(const cute::AutoVectorizingCopyWithAssumedAlignment<T1> &, const cute::Tensor<T3, T4> &, cute::Tensor<T5, T6> &) [0xa6a0]
=========     Device Frame:/home/sean/myrepo/packages/state_kernel/csrc/cutlass/include/cute/algorithm/copy.hpp:299:void cute::copy<cute::ViewEngine<cute::smem_ptr<cutlass::bfloat16_t *>>, cute::Layout<cute::tuple<cute::tuple<cute::C<(int)1>, cute::C<(int)8>>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::tuple<cute::C<(int)0>, cute::C<(int)1>>, cute::C<(int)0>, cute::C<(int)0>>>, cute::ViewEngine<cute::gmem_ptr<cutlass::bfloat16_t *>>, cute::Layout<cute::tuple<cute::tuple<cute::C<(int)1>, cute::C<(int)8>>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::tuple<cute::C<(int)0>, cute::C<(int)1>>, cute::C<(int)0>, cute::C<(int)0>>>>(const cute::Tensor<T1, T2> &, cute::Tensor<T3, T4> &) [0xa6a0]
=========     Device Frame:/home/sean/myrepo/packages/state_kernel/csrc/state_kernel/src/kernel.h:319:void state_kernel::chunk_state_kernel_fwd<State_chunk_traits<cutlass::bfloat16_t, (int)32, (int)4, (int)52384, (int)32, (int)4, (int)128>>(Chunk_state_params) [0xa6a0]
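For context, the report above comes from running the sanitizer against the Python script that launches the kernel, roughly like this (the script name is just a placeholder):

    compute-sanitizer --tool memcheck python run_chunk_state_fwd.py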

I’ve identified the offending code:

    // Represent the full tensors for O
    Tensor mO = make_tensor(
        make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr)),
        make_shape(params.batch_size, params.num_chunks, params.num_heads, paddedExpandedDim, Headdim),
        GenRowMajor{}
        ); // (batch_size, num_chunks, num_heads, paddedExpandedDim, Headdim)
    Tensor gO = local_tile(mO(batch_id, chunk_seq_id, head_id, _, _), Shape<Int<BlockD>, Int<Headdim>>{}, make_coord(dim_id, 0)); // (BlockD, Headdim)
...
    Tensor sO = make_tensor(
        sPK.data() + size(sPK),
        typename Kernel_traits::SmemLayoutO{}); // (BlockD, Headdim)
...
    // copy back to global memory
    typename Kernel_traits::GmemCopyTileO gmem_tiled_copy_O;
    auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tid);

    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
    Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
    cute::copy(tOsO, tOgO);

So essentially the OOB write happens when writing the output of this kernel. The output was allocated this way:

    auto out = torch::empty({batch_size, num_chunks, num_heads, c_tensor.size(0), head_size}, k.options());

I’ve verified that the allocated tensor is contiguous and has the right shape and strides:

    std::cout << "out shape: " << out.sizes() << std::endl;
    std::cout << "out stride: " << out.strides() << std::endl;
    std::cout << "out contiguous: " << out.is_contiguous() << std::endl;

which gives:

    out shape: [4, 31, 12, 52384, 32]
    out stride: [623579136, 20115456, 1676288, 32, 1]
    out contiguous: 1

This error only happens when the batch size (the first dimension) is >= 4, which is even stranger. How should one proceed from here? The error says the address is out of bounds yet inside the allocated region, which I thought could only be explained by a non-contiguous allocation, but apparently that’s not the case…

The cause was using int for the offsets, which is not wide enough for indexing large tensors. With the strides printed above, the flat element offset exceeds INT_MAX (2,147,483,647) once a fourth batch is present (4 × 623,579,136 = 2,494,316,544 elements in total), so a 32-bit offset overflows and the write goes to the wrong address, which is also why the error only shows up for batch_size >= 4.
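A minimal sketch of the overflow, using the strides printed above (standalone host code with illustrative names, not the actual kernel code):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Element strides as printed for the [4, 31, 12, 52384, 32] output tensor.
        const int64_t batch_stride = 623579136;
        const int64_t chunk_stride = 20115456;

        int batch_id = 3;   // last batch (batch_size == 4)
        int chunk_id = 30;  // a late chunk

        // Buggy: the sum is computed in 32-bit int and exceeds INT_MAX
        // (signed overflow, so in practice it wraps to a bogus value).
        int bad_offset = batch_id * (int)batch_stride + chunk_id * (int)chunk_stride;

        // Fixed: do the index arithmetic in 64 bits from the start.
        int64_t good_offset = (int64_t)batch_id * batch_stride
                            + (int64_t)chunk_id * chunk_stride;

        printf("32-bit offset: %d\n", bad_offset);               // wrapped / negative
        printf("64-bit offset: %lld\n", (long long)good_offset); // 2474201088
        return 0;
    }

In the kernel, the analogous fix is to make sure the runtime sizes/strides handed to the tensor construction (and any manual offset math) are 64-bit, e.g. int64_t, so the global-memory offsets never get computed in int.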
