Significantly lower device memory bandwidth when using higher thread counts

Hi,

I am trying to optimize my 1D convolution kernel, code below.

template <int InputChannels, int InputLength, int Padding, int KernelSize, int ChannelsPerThread>
__global__ void conv1d(float *d_input, float *d_weight, float *d_bias, float *d_output)
{
    //define constants
    constexpr int SharedMemLength = constexpr_max(InputLength, KernelSize);
    const int blockId = blockIdx.x;
    const int tdIdx = threadIdx.x;
    const int laneIdx = threadIdx.x % warpSize;
    const int warpIdx = threadIdx.x / warpSize;

    const int input_accesses_per_thread = (InputChannels * InputLength)/(4 * blockDim.x); 
    const int weight_accesses_per_thread = (InputChannels * KernelSize)/(blockDim.x); 
    const int weight_offset = blockId * InputChannels * KernelSize;
    const int padded_input_length = InputLength + Padding * 2;

    //static mem allocations
    float regInput[padded_input_length*ChannelsPerThread] = {0};
    float regFilter[KernelSize*ChannelsPerThread];
    __shared__ float shared_mem[InputChannels * SharedMemLength];
    /*
    //load input from global memory into shared memory 
    for (int channelIndex = 0; channelIndex < input_accesses_per_thread; ++channelIndex){
        int td_offset = 4 * (channelIndex * blockDim.x + tdIdx); 
        int smem_offset = td_offset/32; 
        float4 data = *reinterpret_cast<float4*>(&d_input[td_offset]);
        shared_mem[td_offset + smem_offset + 0] = data.x; 
        shared_mem[td_offset + smem_offset + 1] = data.y; 
        shared_mem[td_offset + smem_offset + 2] = data.z; 
        shared_mem[td_offset + smem_offset + 3] = data.w; 
    }

    __syncthreads(); 

    //load input from shared memory into thread registers
    for (int channelIndex = 0; channelIndex < ChannelsPerThread; ++channelIndex){
        for (int colIndex = 0; colIndex < InputLength; ++colIndex){
            int regIndex = Padding + channelIndex * padded_input_length + colIndex;
            int sharedMemIndex = InputLength * (ChannelsPerThread * tdIdx + channelIndex) + colIndex;
            int smem_offset = sharedMemIndex/32; 
            regInput[regIndex] = shared_mem[sharedMemIndex + smem_offset];
        }
    }

    __syncthreads(); 
    */

    //load weights from global memory into shared memory 
    for (int channelIndex = 0; channelIndex < weight_accesses_per_thread; ++channelIndex){
        int td_offset = (channelIndex * blockDim.x) + tdIdx;
        shared_mem[td_offset] = d_weight[td_offset + weight_offset];
    }

    __syncthreads();

    //load weights from shared memory to thread registers
    for (int channelIndex = 0; channelIndex < ChannelsPerThread; ++channelIndex){
        for (int colIdx = 0; colIdx < KernelSize; ++colIdx){
            int regIndex = channelIndex * KernelSize + colIdx;
            int sharedMemIndex = KernelSize * (ChannelsPerThread * tdIdx + channelIndex) + colIdx;
            regFilter[regIndex] = shared_mem[sharedMemIndex];
        }
    }

    //outer loop iterates over each element in output vector
    #pragma unroll
    for (int tileIdx = 0; tileIdx < InputLength; ++tileIdx){
        float res = 0.0f;
        
        //inner loop performs dot product over all kernel positions and accumulates results
        for(int dotIdx = 0; dotIdx < KernelSize; ++dotIdx){
            for(int channelIndex = 0; channelIndex < ChannelsPerThread; ++channelIndex){
                res += regInput[tileIdx + dotIdx + (channelIndex * padded_input_length)] * regFilter[dotIdx + (channelIndex * KernelSize)];
            }
        }

        shared_mem[tdIdx] = res; 

        __syncthreads(); 

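        //block-wide tree reduction of the per-thread partial sums (as written this sums 256 partials, i.e. it assumes blockDim.x == 256)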
        if (threadIdx.x < 128){
            shared_mem[threadIdx.x] += shared_mem[threadIdx.x + 128];
        }
        __syncthreads();

        if (threadIdx.x < 64){
            shared_mem[threadIdx.x] += shared_mem[threadIdx.x + 64];
        }
        __syncthreads();

        if (threadIdx.x < 32){
            shared_mem[threadIdx.x] += shared_mem[threadIdx.x + 32];
            __syncwarp();
            shared_mem[threadIdx.x] += shared_mem[threadIdx.x + 16];
            __syncwarp();
            shared_mem[threadIdx.x] += shared_mem[threadIdx.x + 8];
            __syncwarp();
            shared_mem[threadIdx.x] += shared_mem[threadIdx.x + 4];
            __syncwarp();
            shared_mem[threadIdx.x] += shared_mem[threadIdx.x + 2];
            __syncwarp();
            shared_mem[threadIdx.x] += shared_mem[threadIdx.x + 1];
            __syncwarp();
            if (threadIdx.x == 0){
                d_output[blockId * InputLength + tileIdx] = shared_mem[0] + d_bias[blockId];
            }
        } 
        __syncthreads();
    }
}

When I run this with ChannelsPerThread = 8, 1024 blocks, and 256 threads per block, I end up with near-peak global memory bandwidth during the loading step (~700 GB/s on my RTX 3090):

    for (int channelIndex = 0; channelIndex < weight_accesses_per_thread; ++channelIndex){
        int td_offset = (channelIndex * blockDim.x) + tdIdx;
        shared_mem[td_offset] = d_weight[td_offset + weight_offset];
    }

When I run the same code with ChannelsPerThread = 2, 1024 blocks, and 1024 threads per block, I get significantly lower bandwidth (~330 GB/s)! The kernel with 256 threads actually has lower theoretical occupancy (~33%) than the one with the higher thread count, due to register pressure. I was under the impression that what mainly determines GMEM bandwidth is coalescing accesses and using vectorized loads. Why does having more threads each loading fewer elements (with accesses still coalesced) result in so much lower bandwidth than fewer threads each performing more loads? I've included the Nsight Compute report (.ncu-rep) files here in case anyone sees anything interesting: Conv1D Profiles - Google Drive
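
For concreteness, this is roughly how I launch the two variants (host-side sketch; the sizes below are placeholders rather than my real problem dimensions, chosen only so that blockDim.x * ChannelsPerThread equals InputChannels):

#include <cuda_runtime.h>

int main()
{
    //placeholder problem sizes, illustrative only
    constexpr int InputChannels  = 2048;
    constexpr int InputLength    = 4;
    constexpr int Padding        = 1;
    constexpr int KernelSize     = 3;
    constexpr int OutputChannels = 1024; //one block per output channel

    float *d_input, *d_weight, *d_bias, *d_output;
    cudaMalloc(&d_input,  InputChannels * InputLength * sizeof(float));
    cudaMalloc(&d_weight, OutputChannels * InputChannels * KernelSize * sizeof(float));
    cudaMalloc(&d_bias,   OutputChannels * sizeof(float));
    cudaMalloc(&d_output, OutputChannels * InputLength * sizeof(float));

    //~700 GB/s configuration: 256 threads per block, 8 channels per thread
    conv1d<InputChannels, InputLength, Padding, KernelSize, 8>
        <<<OutputChannels, 256>>>(d_input, d_weight, d_bias, d_output);

    //~330 GB/s configuration: 1024 threads per block, 2 channels per thread
    conv1d<InputChannels, InputLength, Padding, KernelSize, 2>
        <<<OutputChannels, 1024>>>(d_input, d_weight, d_bias, d_output);

    cudaDeviceSynchronize();
    return 0;
}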

The CUDA profiler is usually a great help in analyzing these kinds of situations, that is, differences in effective memory bandwidth between two code variants. I would suggest giving that a try.
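In particular, the Memory Workload Analysis section of an Nsight Compute report shows the achieved DRAM throughput per kernel (the dram__bytes.sum.per_second metric), which makes it straightforward to compare the two variants side by side.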

Thanks for the quick reply. I did spend a while looking at the traces in NCU but I can’t pinpoint why there should be such a stark difference in bandwidth.

2-Channels-Per-Thread Data:


8-Channels-Per-Thread Data:


I am hitting the same number of sectors in both cases, and loading the same number of bytes. Looking at warp stalls in the source view, both have around 5000 cycles under ‘long scoreboard’ at the lines associated with the global memory loads. The faster kernel actually has half the occupancy (33% vs. 66% for 2 channels per thread) and more shared memory bank conflicts due to the shmem ld/st patterns. I am at a bit of a loss as to why I lose half the bandwidth for a seemingly innocent change in the work done per thread. I should note my line of questioning is mostly out of curiosity (this is for a pedagogical project); of course I would just use the 8-channel kernel in practice.