Optimization opportunity for large vector access

Hello everyone,

I am facing a challenge where I need to optimize a kernel for memory throughput.
Long story short, suppose we need to calculate y = f(x), where x is a vector of some length M, for a batch of different x. The x vectors are generated by an external neural network and stored contiguously in global memory. Currently the best performance is achieved by loading the vectors into shared memory with coalesced accesses (simple example below). However, when the vector length M grows past a certain point, the total shared memory allocation exceeds the 48 KB limit and the kernel fails to compile. If I access x directly from global memory without changing anything else, the performance is too poor to be practical. Is there any trick I can use to keep the same performance (mainly memory throughput, I guess) when dealing with large M?

The simplest kernel that demonstrates the situation is:

template <int M>
__global__ void Kernel(float* output, float* x_head)
{
    __shared__ float x_cache[32][M];
    constexpr auto WORK_PER_THREAD = (M + 31) / 32;
    for (int i = 0;i<32;i++)
    {
        float* x = x_head + M * (blockIdx.x * blockDim.x + i);
        for(int j = 0;j<WORK_PER_THREAD;j++)
        {
            x_cache[i][j * 32 + threadIdx.x] = x_head[j * 32 + threadIdx.x];
        }
    }
    
    float y = f(x_cache[threadIdx.x]);
    output[blockIdx.x * blockDim.x + threadIdx.x] = y;      
}

Thank you!

Can you show function f?

Depending on the GPU architecture, there is more than 48 KB of shared memory available. See table 15 in Programming Guide :: CUDA Toolkit Documentation
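
For reference, a rough (untested) sketch of what the opt-in looks like on cc 7.0+ devices: the statically sized __shared__ array becomes dynamic shared memory, the limit is raised with cudaFuncSetAttribute, and the byte count is passed at launch. The names KernelDynSmem / launchDynSmem and M = 400 are just placeholders; f is assumed to be your existing device function. On a 2080 Ti (cc 7.5) the per-block cap with this opt-in is 64 KB.

//rough sketch, untested: opt in to more than 48 KB of dynamic shared memory (cc >= 7.0)
template <int M>
__global__ void KernelDynSmem(float* output, const float* x_head)
{
    extern __shared__ float x_cache[]; //32 * M floats, laid out as [32][M]
    constexpr int WORK_PER_THREAD = (M + 31) / 32;
    for (int i = 0; i < 32; i++)
    {
        const float* x = x_head + M * (blockIdx.x * blockDim.x + i);
        for (int j = 0; j < WORK_PER_THREAD; j++)
        {
            x_cache[i * M + j * 32 + threadIdx.x] = x[j * 32 + threadIdx.x];
        }
    }
    __syncthreads();

    float y = f(&x_cache[threadIdx.x * M]); //f is your existing device function
    output[blockIdx.x * blockDim.x + threadIdx.x] = y;
}

template <int M>
void launchDynSmem(float* output, const float* x_head, int numBlocks)
{
    const size_t smemBytes = size_t(32) * M * sizeof(float); //e.g. M = 400 -> 50 KB, above the default 48 KB
    //raise the dynamic shared memory limit for this kernel (required above 48 KB)
    cudaFuncSetAttribute(KernelDynSmem<M>,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         int(smemBytes));
    KernelDynSmem<M><<<numBlocks, 32, smemBytes>>>(output, x_head);
}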

Hi, thank you for the reference. Yes, I should try raising the compute capability target of the compiler, since I have a 2080 Ti. But I am afraid that for some cases even the 163 KB of compute capability 8.0 cannot satisfy the requirement (M can sometimes be 2048, and 32 vectors * 2048 floats * 4 bytes is 256 KB).

f(x) is rather complicated in my application; x is accessed both randomly and linearly, so you can think of it as some kind of sampling and probability-evaluation process. Therefore I am afraid it is not possible to show the details of f(x) here.

How about using more than 1 thread per x? For example, one warp / 32 threads? With M = 2048, 6 warps with 8 KB of shared memory each could fit on the GPU. Of course, that may require some refactoring of f.

Well, it makes sense, but then we need to launch several times more threads. Intuitively, won't that make the application several times slower?

If you are worried about scheduling overhead, you could simply keep the same number of threads, but process multiple x per warp.

Keep in mind that ideally the work per thread is reduced by a factor of 32 as well when 1 warp per x is used. I would just try it out.

I think there might be some misunderstanding. Let me confirm: in my code above, each warp calculates 32 x vectors instead of 1, right? But you seem to suggest that each warp should only calculate 1. That seems like it would be very slow…

Yes, you currently use 1 x per thread / 32 x per warp. I suggest using 1 x per warp.
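
To illustrate the idea without cooperative groups, here is a rough, untested sketch of "1 x per warp" using a plain warp shuffle reduction. It assumes f is a simple reduction over x (like the toy f in the full example below); kernelWarpPerX and numX are placeholder names.

//rough sketch, untested: one warp per x, partial sums combined with __shfl_down_sync
template <int M>
__global__ void kernelWarpPerX(const float* x_head, float* output, int numX)
{
    const int warpId = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
    const int lane   = threadIdx.x % 32;
    if (warpId >= numX) return; //whole warp exits together

    const float* x = x_head + size_t(warpId) * M;

    float partial = 0.0f;
    for (int i = lane; i < M; i += 32) { //consecutive lanes read consecutive elements -> coalesced
        partial += expf(sinf(x[i]));
    }
    //warp-wide sum of the 32 partial results, ends up in lane 0
    for (int offset = 16; offset > 0; offset /= 2) {
        partial += __shfl_down_sync(0xffffffff, partial, offset);
    }
    if (lane == 0) {
        output[warpId] = partial;
    }
}

Note that with one x per warp the 32 lanes read consecutive elements of x, so the global loads are already coalesced even before anything is staged into shared memory.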

Here is a toy example with 32 threads per x. (I fixed your kernel in this code: you were missing synchronization before and after accessing shared memory, and your code also did not use the pointer x, but always read from the base pointer x_head.)

M = 128, 256000 vectors
Timings on my machine with Titan Xp:
Timing 1 thread per x: 0.00134403s (8000 blocks, 32 threads each)
Timing 32 threads per x: 0.000329215s (480 blocks, 128 threads each)

//nvcc -arch=sm_61 -lineinfo -g -O3 -std=c++17
#include <algorithm>
#include <cassert>
#include <chrono>
#include <iostream>

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

#include <thrust/fill.h>
#include <thrust/execution_policy.h>

namespace cg = cooperative_groups;


__global__
void printkernel(float* C1, float* C2){
    for(int i = 0; i < 10; i++){
        printf("%f ", C1[i]);
    }
    printf("\n");
    for(int i = 0; i < 10; i++){
        printf("%f ", C2[i]);
    }
    printf("\n");
}


template<int M>
__device__
float f(const float* x){
    float result = 0;
    for(int i = 0; i < M; i++){
        result += expf(sin(x[i]));
    }
    return result;
}

template<int M, class Group>
__device__
float f(Group group, const float* x){
    float result = 0;
    for(int i = group.thread_rank(); i < M; i += group.size()){
        result += expf(sin(x[i]));
    }
    result = cg::reduce(group, result, cg::plus<float>{});
    return result;
}


template<int M>
__global__
void kernel1(const float* x_head, float* output){
    __shared__ float x_cache[32][M];
    constexpr auto WORK_PER_THREAD = (M + 31) / 32;
    for (int i = 0;i<32;i++)
    {
        const float* x = x_head + M * (blockIdx.x * blockDim.x + i);
        for(int j = 0;j<WORK_PER_THREAD;j++)
        {
            x_cache[i][j * 32 + threadIdx.x] = x[j * 32 + threadIdx.x];
        }
    }
    __syncthreads();
    
    float y = f<M>(&x_cache[threadIdx.x][0]);
    output[blockIdx.x * blockDim.x + threadIdx.x] = y;
}


template<int M, int blocksize, int groupsize>
__global__
void kernel4(const float* x_head, float* output, int numX){
    constexpr int numGroupsPerBlock = blocksize / groupsize;

    auto group = cg::tiled_partition<groupsize>(cg::this_thread_block());
    const int numGroupsInGrid = (blockDim.x * gridDim.x) / groupsize;
    const int groupIdInGrid = (threadIdx.x + blockIdx.x * blockDim.x) / groupsize;

    
    //only need 1 x array smem per group
    __shared__ float x_cache[numGroupsPerBlock][M];
    
    //each group processes 1 x / each block processes numGroupsPerBlock x
    for(int xIndex = groupIdInGrid; xIndex < numX; xIndex += numGroupsInGrid){
        //load x vector of group to smem
        const float* x = x_head + M * xIndex;
        for(int j = group.thread_rank(); j < M; j += group.size()){
            x_cache[group.meta_group_rank()][j] = x[j];
        }
        //wait for shared memory
        group.sync();

        //process x vector with group
        float y = f<M>(group, &x_cache[group.meta_group_rank()][0]);
        if(group.thread_rank() == 0){
            output[xIndex] = y;
        }
        //wait until shared memory can be reused
        group.sync();
    }
}



int main(){
    constexpr size_t numColumns = 128;
    size_t numRows = 256000;

    float* A = nullptr; 
    float* C1 = nullptr;
    float* C2 = nullptr;
    cudaMalloc(&A, sizeof(float) * numColumns * numRows);
    cudaMalloc(&C1, sizeof(float) * numRows);
    cudaMalloc(&C2, sizeof(float) * numRows);

    thrust::fill(
        thrust::device,
        A,
        A + numColumns * numRows,
        0.01f
    );

    int deviceId = 0;
    int numSMs = 0;
    cudaGetDevice(&deviceId);
    cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, deviceId);

    //32 threads per block. 1 thread per x
    dim3 block1(32);
    dim3 grid1((numRows + 31) / 32);
    auto t1 = std::chrono::system_clock::now();
    kernel1<numColumns><<<grid1, block1>>>(A, C1);
    cudaDeviceSynchronize();
    auto t2 = std::chrono::system_clock::now();
    std::cout << "Timing 1 thread per x: " << std::chrono::duration<double>(t2 - t1).count() << "s\n";

    //128 threads per block, 1 group per x
    constexpr int blocksize4 = 128;
    constexpr int groupsize4 = 32;
    constexpr int xPerBlock4 = blocksize4 / groupsize4;

    int maxBlocksPerSM4 = 0;   
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &maxBlocksPerSM4,
        kernel4<numColumns, blocksize4, groupsize4>,
        blocksize4, 
        0
    );
    dim3 block4(blocksize4);
    dim3 grid4(std::min(size_t(maxBlocksPerSM4) * numSMs, (numRows + xPerBlock4 - 1) / xPerBlock4));
    auto t3 = std::chrono::system_clock::now();
    kernel4<numColumns, blocksize4, groupsize4><<<grid4, block4>>>(A, C2, numRows);
    cudaDeviceSynchronize();
    auto t4 = std::chrono::system_clock::now();
    std::cout << "Timing 32 threads per x: " << std::chrono::duration<double>(t4 - t3).count() << "s\n";

    printkernel<<<1,1>>>(C1, C2);
    cudaDeviceSynchronize();
}

Hi, yes, this seems like an interesting approach. So far I have never tried using multiple threads in a warp for a single x; maybe it will work for my purpose! Thank you for the example. I might need some time to implement it and see what happens.