How to make nvcc place variables in register instead of local memory when there's clearly enough space?

I have a kernel which is using 97 registers with a headroom of 71 registers (obtained from nsight compute, as shown below)

What bothers me is that I have a cute::Tensor that I’d like to be placed in the register file, but the compiler decides to place it in local memory, as indicated by both the compilation output and the nsight compute profiling output

    352 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 126 registers, 25344 bytes smem, 528 bytes cmem[0]

I know that dynamically indexing an array is not supported for register-resident data, so I have some template programming to statically determine the index.

...
// Base case of the binary index dispatch: VALUE must be 0 or 1.  Each
// branch binds CONST_NAME as a compile-time constant equal to VALUE and
// then invokes LAMBDA, so indexing done inside LAMBDA sees a constexpr
// index instead of a run-time one.
// Comments are kept outside the macro body on purpose: backslash-newline
// splicing runs before comment removal, so a trailing // comment inside
// the body would swallow the continuation line.
#define BINARY_DIM_SWITCH_2(VALUE, CONST_NAME, LAMBDA) \
    if (VALUE == 0)                                    \
    {                                                  \
        constexpr static int CONST_NAME = 0;           \
        LAMBDA();                                      \
    }                                                  \
    else                                               \
    {                                                  \
        constexpr static int CONST_NAME = 1;           \
        LAMBDA();                                      \
    }

// Dispatch for VALUE in [0, 4): tests the high bit (VALUE >= 2) at run
// time, then recurses into the 2-way switch on the remainder.  In the
// upper half, the inner switch binds CONST_NAME##_offset for VALUE - 2,
// and the wrapping lambda re-derives CONST_NAME = CONST_NAME##_offset + 2
// as a fresh constexpr before calling LAMBDA.
#define BINARY_DIM_SWITCH_4(VALUE, CONST_NAME, LAMBDA)                  \
    if (VALUE < 2)                                                      \
    {                                                                   \
        BINARY_DIM_SWITCH_2(VALUE, CONST_NAME, LAMBDA);                 \
    }                                                                   \
    else                                                                \
    {                                                                   \
        BINARY_DIM_SWITCH_2(VALUE - 2, CONST_NAME##_offset, [&]() {     \
            constexpr static int CONST_NAME = CONST_NAME##_offset + 2;  \
            LAMBDA(); }); \
    }

// Dispatch for VALUE in [0, 8): same scheme as BINARY_DIM_SWITCH_4 one
// level up — branch on VALUE >= 4, recurse on the remainder, and add
// the offset (4) back into CONST_NAME as a new constexpr.
#define BINARY_DIM_SWITCH_8(VALUE, CONST_NAME, LAMBDA)                  \
    if (VALUE < 4)                                                      \
    {                                                                   \
        BINARY_DIM_SWITCH_4(VALUE, CONST_NAME, LAMBDA);                 \
    }                                                                   \
    else                                                                \
    {                                                                   \
        BINARY_DIM_SWITCH_4(VALUE - 4, CONST_NAME##_offset, [&]() {     \
            constexpr static int CONST_NAME = CONST_NAME##_offset + 4;  \
            LAMBDA(); }); \
    }

// Dispatch for VALUE in [0, 16): branch on VALUE >= 8, recurse on the
// remainder, and add the offset (8) back into CONST_NAME.
#define BINARY_DIM_SWITCH_16(VALUE, CONST_NAME, LAMBDA)                 \
    if (VALUE < 8)                                                      \
    {                                                                   \
        BINARY_DIM_SWITCH_8(VALUE, CONST_NAME, LAMBDA);                 \
    }                                                                   \
    else                                                                \
    {                                                                   \
        BINARY_DIM_SWITCH_8(VALUE - 8, CONST_NAME##_offset, [&]() {     \
            constexpr static int CONST_NAME = CONST_NAME##_offset + 8;  \
            LAMBDA(); }); \
    }

// Dispatch for VALUE in [0, 32): branch on VALUE >= 16, recurse on the
// remainder, and add the offset (16) back into CONST_NAME.
#define BINARY_DIM_SWITCH_32(VALUE, CONST_NAME, LAMBDA)                   \
    if (VALUE < 16)                                                       \
    {                                                                     \
        BINARY_DIM_SWITCH_16(VALUE, CONST_NAME, LAMBDA);                  \
    }                                                                     \
    else                                                                  \
    {                                                                     \
        BINARY_DIM_SWITCH_16(VALUE - 16, CONST_NAME##_offset, [&]() {     \
            constexpr static int CONST_NAME = CONST_NAME##_offset + 16;   \
            LAMBDA(); }); \
    }

// Dispatch for VALUE in [0, 64): branch on VALUE >= 32, recurse on the
// remainder, and add the offset (32) back into CONST_NAME.
#define BINARY_DIM_SWITCH_64(VALUE, CONST_NAME, LAMBDA)                   \
    if (VALUE < 32)                                                       \
    {                                                                     \
        BINARY_DIM_SWITCH_32(VALUE, CONST_NAME, LAMBDA);                  \
    }                                                                     \
    else                                                                  \
    {                                                                     \
        BINARY_DIM_SWITCH_32(VALUE - 32, CONST_NAME##_offset, [&]() {     \
            constexpr static int CONST_NAME = CONST_NAME##_offset + 32;   \
            LAMBDA(); }); \
    }

// Entry point: maps a dynamic index VALUE in [0, HEADDIM) onto a
// compile-time constant CONST_NAME visible inside LAMBDA, via a run-time
// binary search over the power-of-two dispatchers above.  The trailing
// static_assert rejects unsupported HEADDIM values at compile time, so
// HEADDIM must itself be a constant expression.
// NOTE(review): the HEADDIM tests are plain run-time `if`s; since HEADDIM
// is constexpr at the call sites shown in this thread, `if constexpr`
// would prune the dead branches at compile time — worth confirming
// whether that changes the generated code.
// NOTE(review): HEADDIM == 1 with VALUE != 0 matches no branch and
// silently executes nothing — presumably unreachable; verify at callers.
#define BINARY_DIM_SWITCH(VALUE, CONST_NAME, HEADDIM, LAMBDA) \
    if (HEADDIM == 1 && VALUE == 0)                           \
    {                                                         \
        constexpr static int CONST_NAME = 0;                  \
        LAMBDA();                                             \
    }                                                         \
    else if (HEADDIM == 2)                                    \
    {                                                         \
        BINARY_DIM_SWITCH_2(VALUE, CONST_NAME, LAMBDA);       \
    }                                                         \
    else if (HEADDIM == 4)                                    \
    {                                                         \
        BINARY_DIM_SWITCH_4(VALUE, CONST_NAME, LAMBDA);       \
    }                                                         \
    else if (HEADDIM == 8)                                    \
    {                                                         \
        BINARY_DIM_SWITCH_8(VALUE, CONST_NAME, LAMBDA);       \
    }                                                         \
    else if (HEADDIM == 16)                                   \
    {                                                         \
        BINARY_DIM_SWITCH_16(VALUE, CONST_NAME, LAMBDA);      \
    }                                                         \
    else if (HEADDIM == 32)                                   \
    {                                                         \
        BINARY_DIM_SWITCH_32(VALUE, CONST_NAME, LAMBDA);      \
    }                                                         \
    else if (HEADDIM == 64)                                   \
    {                                                         \
        BINARY_DIM_SWITCH_64(VALUE, CONST_NAME, LAMBDA);      \
    }                                                         \
    else                                                      \
    {                                                         \
        static_assert(HEADDIM == 1 || HEADDIM == 2 || HEADDIM == 4 || HEADDIM == 8 || HEADDIM == 16 || HEADDIM == 32 || HEADDIM == 64, "Unsupported HEADDIM value"); \
    }

...

// Adds `val` to element DIM of `arr`, with DIM supplied as a template
// argument so the element index is a compile-time constant — the
// prerequisite (per this thread) for the element to be register-resident
// rather than spilled to local memory.
// `Tensor` is any type supporting operator[] with an integral index;
// here it is a cute::Tensor over register-backed storage — TODO confirm
// the engine type at the call site.
template <int DIM, typename Tensor, typename T>
__forceinline__ __device__ void static_add(Tensor &arr, T val)
{
    arr[DIM] += val;
}
...
auto rdQK_acc = make_tensor<ElementAccum>(Shape<Int<Headdim>>{});
...
BINARY_DIM_SWITCH(rD(i_mod_stage_D, d), DIM, Headdim, [&]()
                                      { static_add<DIM>(rdQK_acc, rdQK_acc_buffer[d]); });
...

where the first bit of code defines macros for doing a switch-case on the index, so I think the compiler should be able to resolve the index statically.

Despite the effort, the tensor is still placed in local memory. What am I missing? Is this possibly cute related?

Sanity check: Are you looking at code generated by a release build with full optimization?

Thread-local data objects are placed into registers as part of optimizations. The compiler does this based on heuristics which are not disclosed to the general public, and subject to change without notice at any time. By observation, the compiler appears to have a notion of trading off faster access with occupancy, and usually makes good decisions regarding the trade-off. I see several possibilities:

(1) Despite the use of template meta-programming, the compiler is unable to resolve all accesses to the data object to fixed addresses at compile time. Examining the generated PTX may provide clues as to what the residual run-time decisions are.

(2) The compiler is able to resolve all accesses to the data object to fixed addresses at compile time. However, the data object’s size exceeds the limit on object size set by the compiler’s heuristic.

(3) The compiler is able to resolve all accesses to the data object to fixed addresses at compile time and satisfies the object-size heuristic of the compiler. However, another compiler heuristic determines that using additional registers will likely decrease occupancy significantly enough to create a net-negative impact on performance.

You may want to try reduced variants of your current code to see at which point the compiler response changes in the way you desire. Depending on the outcome of your investigation, you may ultimately decide to file an enhancement request with NVIDIA.

What does “cute” refer to here?

1 Like

If the stack frame only contains the tensor, 352 bytes would equal 88 registers which is greater than your determined headroom of 71 registers.

1 Like

here

1 Like

Thanks @njuffa ! Cute is the newest cutlass abstraction.

(1) I’m looking at the PTX code but I don’t know what specifically I should be looking for. I’ve identified the offending code that does local memory load and store as shown here

Here’s the start of the ptx code

Fatbin ptx code:
================
arch = sm_80
code version = [8,2]
host = linux
compile_size = 64bit
compressed








.version 8.2
.target sm_80
.address_size 64



.global .align 4 .u32 _ZZN12state_kernel22query_state_kernel_fwdI18Query_state_traitsIN7cutlass10bfloat16_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_c
hunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEEEEv18Query_state_paramsE6stages = 2;
.global .align 1 .b8 _ZN79_INTERNAL_10e41dc4_40_query_state_fwd_bf16_hdim32_deg4_sm80_cu_a583c0e2_30990266thrust6system6detail10sequential3seqE[1];
.global .align 1 .b8 _ZN79_INTERNAL_10e41dc4_40_query_state_fwd_bf16_hdim32_deg4_sm80_cu_a583c0e2_30990264cuda3std3__48in_placeE[1];
.global .align 1 .b8 _ZN79_INTERNAL_10e41dc4_40_query_state_fwd_bf16_hdim32_deg4_sm80_cu_a583c0e2_30990264cuda3std6ranges3__45__cpo4swapE[1];
.global .align 1 .b8 _ZN79_INTERNAL_10e41dc4_40_query_state_fwd_bf16_hdim32_deg4_sm80_cu_a583c0e2_30990264cute1_E[1];
.global .align 1 .b8 _ZN79_INTERNAL_10e41dc4_40_query_state_fwd_bf16_hdim32_deg4_sm80_cu_a583c0e2_30990264cute7productE[1];
.extern .shared .align 16 .b8 _ZN12state_kernel5smem_E[];

.visible .entry _ZN12state_kernel22simple_query_state_fwdI25Simple_query_state_traitsIN7cutlass10bfloat16_tELi32ELi4ELi52384ELi32ELi4E25Simple_state_
chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4EEEEEv18Query_state_params(
.param .align 8 .b8 _ZN12state_kernel22simple_query_state_fwdI25Simple_query_state_traitsIN7cutlass10bfloat16_tELi32ELi4ELi52384ELi32ELi4E25Simple_st
ate_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4EEEEEv18Query_state_params_param_0[160]
)
.maxntid 128, 1, 1
.minnctapersm 1
{
.reg .pred %p<14>;
.reg .b16 %rs<121>;
.reg .f32 %f<185>;
.reg .b32 %r<509>;
.reg .b64 %rd<68>;
...

Here’s more relevant calls to ld.local and so on

looking at the ld.local calls, it seems like some runtime address calculation is happening, does that mean despite the meta-programming, the compiler is unable to resolve all accesses to the data array? But the C++ code is literally calling static_add&lt;DIM&gt;(array, value) where DIM is the index. Why wouldn’t the compiler be able to determine the index?

(2) Is there a place where such heuristics are documented (even roughly)? I’ve seen people saying 16 registers is the limit but I’m not sure.

(3) If that’s the case, yeah, it seems like I should file an enhancement request

and yes, I’m using a release build with full optimization

It’s possible. Is there a way of determining what the stack frame bytes are allocated to?

Looking at the PTX code I find it hard to align with what I see in nsight-compute

.visible .entry _ZN12state_kernel22simple_query_state_fwdI25Simple_query_state_traitsIN7cutlass10bfloat16_tELi32ELi4ELi52384ELi32ELi4E25Simple_state_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4EEEEEv18Query_state_params(
.param .align 8 .b8 _ZN12state_kernel22simple_query_state_fwdI25Simple_query_state_traitsIN7cutlass10bfloat16_tELi32ELi4ELi52384ELi32ELi4E25Simple_state_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4EEEEEv18Query_state_params_param_0[160]
)
.maxntid 128, 1, 1
.minnctapersm 1
{
.reg .pred %p<14>;
.reg .b16 %rs<121>;
.reg .f32 %f<185>;
.reg .b32 %r<509>;
.reg .b64 %rd<68>;

here I’m seeing register allocation of 14 predicates, 121 bf16, 185 f32, 509 b32, and 68 b64 registers, which looks nothing like the 97 registers I see from nsight compute.

PTX isn’t the right way to correlate with info from nsight compute

You’ll need to get familiar (somewhat) with SASS in order to come to grips with actual register usage.

I’m puzzled by the lack of clarity on the tensor size. It’s your tensor, isn’t it?

How big is that cute::Tensor ?

To “place an array in a register file”, there are at least 2 things that need to be satisfied:

  1. All indexing must be discoverable/computable at compile time.
  2. The array cannot be too large

I think njuffa was suggesting PTX study to go after item 1, i.e. to try and discover what the compiler might be having trouble with (if any) from an indexing perspective.

With respect to 2, PTX is a virtual architecture and intermediate code. It has an “arbitrary” number of registers. So from a size perspective, the actual decision to locate data “in registers” must be done after the PTX creation phase. That decision cannot be done/completed until the PTX is compiled to SASS.

A thread has a hardware limit of 255 32 bit (SASS) registers, so that naturally upper-bounds the maximum size of any such array, but in practice you cannot achieve that, and perhaps might not want to locate very large arrays in registers anyway, because of side effects e.g. on achievable occupancy. Register usage is not “free”. The total number of registers available to a threadblock is also a (SM) hardware limit, so using more registers per thread will eventually lead to a point where the number of threads per threadblock is limited, and also will eventually lead to a point where the number of threads per SM is limited, which can be a perf limiter.

1 Like

Thanks a lot for your response @Robert_Crovella !

How big is that cute::Tensor ?

The tensor is of size 32 (f32) or 64, depending on some static parameters, created in this way

auto rdQK_acc = cute::make_tensor<float>(Shape<Int<32>>{});

I had thought that this size isn’t that large because I only have 4 warps at maximum and nsight compute tells me that there’s enough headroom (71) at the current compiled result.

  1. All indexing must be discoverable/computable at compile time.

I’m aware of this and the indexing are all computable at compile time, as evident by the usage of this template function

template <int DIM, typename Tensor, typename T>
__forceinline__ __device__ void static_add(Tensor &arr, T val)
{
    arr[DIM] += val;
}
...

BINARY_DIM_SWITCH(rD(i_mod_stage_D, d), DIM, Headdim, [&]()
                                      { static_add<DIM>(rdQK_acc, rdQK_acc_buffer[d]); });

where BINARY_DIM_SWITCH is just a bunch of switch case.

How should one look at the PTX output and determine if the compile has trouble with indexing it? I’m seeing PTX output like this

ld.local.u32 %r225, [%rd81+24];
add.s32 %r226, %r225, %r197;

Does this mean the compiler had trouble determining the index statically?

The array cannot be too large

But how large is too large? Perhaps the size limit is set by some compiler heuristic, but what factors are important here?

Please let me know if there’s anything else I should provide to make the problem clearer!

%rd81 is presumably a pointer to the start of the data object. We have to find out where that comes from, and that should indicate how the compiler has transformed the indexing computation. Showing code as an image is always a BAD idea, we would like something searchable. The good thing about nvcc generated PTX code is that it is in SSA form, meaning each virtual register is written to exactly once.

Best I can make it out, %rd81 is a pointer computed at runtime, possibly inside a loop if one assumes $L_881_8 is a loop-starting label. From what I can see here, this would point in the direction of my item (1): The addressing does not seem to be fully resolvable at compile time. One would want to correlate the generated PTX with the source code to try and figure out the details.

Has the size issue pointed out by @striker159 been resolved?

As I stated, compiler heuristics are not documented publicly and can change at any time. That means we do not know a numeric value for “too large”. One could try to reverse engineer it, for a particular version of the compiler.

If we want to try and distinguish between a size issue and an indexing issue, we would, as an experiment, reduce the size of the data object (in steps) to see whether this results in the desired code.

Somewhere earlier in the thread it was mentioned that the indexing involves a switch statement. I am not a compiler engineer, but wonder how a compiler can safely reason through a switch statement. It seems more challenging than “flattening” normal indexing arithmetic. One might hypothesize that the compiler needs at least a complete mapping for that. The presence of a default label (and whether it is aliased with any other label, and if so, which one) may play into that. Seems worthy of exploration.

1 Like

I was able to remove all the local memory usage by removing shared memory prefetching in my kernel (which takes up some register space)

$ cuobjdump -ptx chunk_state_bwd_fp16_hdim32_deg4_sm80.o | grep ld.local
$

what I notice in this version of the PTX code is the appearance of the static index offsets as global variables that are used later to statically determine the index.

$ cuobjdump -ptx chunk_state_bwd_fp16_hdim32_deg4_sm80.o | head -n 40

Fatbin ptx code:
================
arch = sm_80
code version = [8,2]
host = linux
compile_size = 64bit
compressed








.version 8.2
.target sm_80
.address_size 64


.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE4_clEvE3DIM = 2;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE5_clEvE3DIM = 3;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE8_clEvE3DIM = 2;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE9_clEvE3DIM = 3;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE10_clEvE3DIM = 4;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE11_clEvE3DIM = 5;
.global .align 4 .u32 _ZZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE12_clEvENKUlvE_clEvE3DIM = 6;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE12_clEvE10DIM_offset = 2;
.global .align 4 .u32 _ZZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE13_clEvENKUlvE_clEvE3DIM = 7;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE13_clEvE10DIM_offset = 3;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE16_clEvE3DIM = 2;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE17_clEvE3DIM = 3;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE18_clEvE3DIM = 4;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE19_clEvE3DIM = 5;
.global .align 4 .u32 _ZZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE20_clEvENKUlvE_clEvE3DIM = 6;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE20_clEvE10DIM_offset = 2;
.global .align 4 .u32 _ZZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE21_clEvENKUlvE_clEvE3DIM = 7;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE21_clEvE10DIM_offset = 3;
.global .align 4 .u32 _ZZZN12state_kernel11SympowStateILi1ELi1ELb1ELb1E22Chunk_state_bwd_traitsIN7cutlass6half_tELi32ELi4ELi52384ELi32ELi4ELi128E18State_chunk_traitsIS3_Li32ELi4ELi52384ELi32ELi4ELi128EEELb0EE14backpropColumnIN4cute6TensorINS9_10ViewEngineINS9_8smem_ptrIPiEEEENS9_6LayoutINS9_5tupleIJNS9_1CILi32EEENSI_ILi4EEEEEENSH_IJSK_NSI_ILi1EEEEEEEEEENSA_ISF_NSG_INSH_IJSJ_EEENSH_IJSM_EEEEEEENSA_INSB_INSC_IPS3_EEEENS9_14ComposedLayoutINS9_7SwizzleILi2ELi3ELi3EEENSI_ILi0EEENSG_INSH_IJNSI_ILi128EEESJ_EEENSH_IJSJ_SM_EEEEEEEEENSA_ISW_NSX_INSY_ILi3ELi3ELi3EEES10_S14_EEEELb1EvEEviT_T0_T1_T2_ENKUlvE22_clEvE3DIM = 8;
...

This tells me that the meta-programming worked, but it’s probably true that the compiler decides there’s not enough headroom to store the cute::Tensor in registers. But for reasons I’m not fully aware of, nsight compute shows a register headroom larger than the size of the cute::Tensor. It could be the case that in order to fully put the tensor in registers, more bytes are needed than just the tensor elements themselves (for example, cute::Tensor might have internal indexing/layout variables, etc.)?

Anyway, I think this issue is fixed, thanks @njuffa @Robert_Crovella @striker159 !