Optimizing a CUDA kernel used with PyTorch

I’m working on a research paper for which I have to implement a CUDA kernel for PyTorch. I have only basic knowledge of CUDA programming, and although I managed to get a working implementation of the idea, I feel it is far from properly optimized.

The kernel takes three tensors as input: the input tensor and two offset tensors. The input tensor has shape [T, B, R] and each offset tensor has shape [T, B]. The idea is that the kernel converts the relative offsets into absolute offsets and then uses these offsets to index into the input tensor several times, combining the gathered values into the output tensor (a small numeric example of what one output element looks like follows the kernel). Below is the kernel.

// Headers the snippet relies on (ATen accessors, the current CUDA stream, AT_CUDA_CHECK).
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

template <typename scalar_t>
__global__ void EncoderKernel(const at::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> input,
    const at::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> offset_left,
    const at::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> offset_right,
    const scalar_t max_left,   // __restrict__ only applies to pointers, so it is dropped for these scalars
    const scalar_t max_right,
    at::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> output){

    const int length = input.size(0);
    const int batchSize = input.size(1);
    const int r_dim = input.size(2);

    // One thread per element of the flattened [T, B, R] output. A plain 32-bit
    // multiply is as fast as __umul24 on current GPUs and does not truncate the
    // block index to 24 bits.
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    const int rIdx = index % r_dim;
    const int batchIdx = (index / r_dim) % batchSize;
    const int tokenIdx = (index / r_dim) / batchSize;


    // Only tokenIdx can actually run out of range (rIdx and batchIdx come from
    // modulo arithmetic), but the extra checks are cheap.
    if (tokenIdx < length && batchIdx < batchSize && rIdx < r_dim) {
        const scalar_t left_off = offset_left[tokenIdx][batchIdx];
        const scalar_t right_off = offset_right[tokenIdx][batchIdx];

        // Convert the relative offsets into absolute (fractional) token positions,
        // clamped to the valid range [0, T-1].
        const scalar_t true_left_off = clamp(tokenIdx - left_off * max_left, static_cast<scalar_t>(0.0), static_cast<scalar_t>(length - 1));
        const scalar_t true_right_off = clamp(tokenIdx + right_off * max_right, static_cast<scalar_t>(0.0), static_cast<scalar_t>(length - 1));

        // Integer neighbours of the fractional positions, used for the linear interpolation.
        const int32_t ind_floor_left = clamp(static_cast<int32_t>(floor(true_left_off)), 0, length - 1);
        const int32_t ind_ceil_left = clamp(static_cast<int32_t>(ceil(true_left_off)), 0, length - 1);

        const int32_t ind_floor_right = clamp(static_cast<int32_t>(floor(true_right_off)), 0, length - 1);
        const int32_t ind_ceil_right = clamp(static_cast<int32_t>(ceil(true_right_off)), 0, length - 1);

        // Interpolation weights.
        const scalar_t alpha_left = ind_ceil_left - true_left_off;
        const scalar_t alpha_right = true_right_off - ind_floor_right;


        // Interpolated value at the right position. The constants are cast to
        // scalar_t so that half/float inputs are not promoted to fp64 by a
        // plain 1.0 double literal.
        const scalar_t right_val = (static_cast<scalar_t>(1) - alpha_right) * input[ind_floor_right][batchIdx][rIdx]
            + alpha_right * input[ind_ceil_right][batchIdx][rIdx];

        // Interpolated value one position before the left index; anything that
        // falls below index 0 contributes zero.
        const scalar_t left_val = alpha_left * ((ind_floor_left - 1 < 0) ? static_cast<scalar_t>(0) : input[ind_floor_left - 1][batchIdx][rIdx])
            + (static_cast<scalar_t>(1) - alpha_left) * ((ind_ceil_left - 1 < 0) ? static_cast<scalar_t>(0) : input[ind_ceil_left - 1][batchIdx][rIdx]);

        output[tokenIdx][batchIdx][rIdx] = right_val - left_val;
    }
}
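
To make the indexing a bit more concrete, here is a made-up numeric example of one output element (the numbers are only for illustration). With T = 10, max_left = max_right = 4, tokenIdx = 6, offset_left[6][b] = 0.6 and offset_right[6][b] = 0.25:

true_left_off  = 6 - 0.6 * 4  = 3.6  ->  floor 3, ceil 4, alpha_left = 4 - 3.6 = 0.4
true_right_off = 6 + 0.25 * 4 = 7.0  ->  floor 7, ceil 7, alpha_right = 0
output[6][b][r] = input[7][b][r] - (0.4 * input[2][b][r] + 0.6 * input[3][b][r])

So the left side reads one position before the interpolated index (hence the 2 and 3 rather than 3 and 4) and contributes zero whenever that index would drop below 0.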

and this is the code where I launch the kernel

void Encoder(const at::Tensor & input,
    const at::Tensor & offset_left, const at::Tensor & offset_right,
    int max_left, int max_right,
    at::Tensor & output) {

    const int length = input.size(0);
    const int batchSize = input.size(1);
    const int r_dim = input.size(2);


    // One thread per element of the [T, B, R] output. The product has to fit in a
    // 32-bit int, which the packed_accessor32 accessors assume anyway.
    const dim3 blockSize(128);
    const dim3 gridSize((length * batchSize * r_dim + blockSize.x - 1) / blockSize.x);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output.scalar_type(), "gpu::Encoder", ([&] {
        
        auto inputAcsr = input.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto offsetLeftAcsr = offset_left.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>();
        auto offsetRightAcsr = offset_right.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>();
        auto outputAcsr = output.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();

        scalar_t max_left_f = static_cast<scalar_t>(max_left);
        scalar_t max_right_f = static_cast<scalar_t>(max_right);

        EncoderKernel<scalar_t><<<gridSize, blockSize, 0, at::cuda::getCurrentCUDAStream()>>> (
                inputAcsr, offsetLeftAcsr, offsetRightAcsr, max_left_f, max_right_f, outputAcsr);

    }));

    AT_CUDA_CHECK(cudaGetLastError());
}
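
For completeness, the Encoder function is exposed to Python through a standard extension binding, roughly like this (the module and function names are arbitrary):

// Binding sketch; TORCH_EXTENSION_NAME comes from the setuptools/JIT build.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("encoder", &Encoder, "Relative-offset encoder (CUDA)");
}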

One idea I have for optimizing it is to split the offset conversion and the offset-based indexing into two kernels that are launched one after the other (a rough sketch is below), but I don’t really know whether this would help. Any ideas would be very much appreciated.
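
To make the split idea a bit more concrete, this is roughly what I have in mind for the first kernel (completely untested, and the buffer names are just placeholders); the second kernel would then only do the gathers and the arithmetic over the R dimension:

// Kernel 1: one thread per (token, batch) pair converts the relative offsets
// into integer indices and interpolation weights, so the conversion is done
// once per (t, b) instead of once per (t, b, r) element.
template <typename scalar_t>
__global__ void OffsetKernel(
    const at::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> offset_left,
    const at::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> offset_right,
    const scalar_t max_left, const scalar_t max_right, const int length,
    at::PackedTensorAccessor32<int32_t, 3, torch::RestrictPtrTraits> left_idx,    // [T, B, 2] (floor, ceil)
    at::PackedTensorAccessor32<int32_t, 3, torch::RestrictPtrTraits> right_idx,   // [T, B, 2] (floor, ceil)
    at::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> alpha_left,
    at::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> alpha_right){

    const int batchSize = offset_left.size(1);
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    const int batchIdx = index % batchSize;
    const int tokenIdx = index / batchSize;

    if (tokenIdx < length) {
        // Same conversion as in EncoderKernel above.
        const scalar_t true_left = clamp(tokenIdx - offset_left[tokenIdx][batchIdx] * max_left,
            static_cast<scalar_t>(0.0), static_cast<scalar_t>(length - 1));
        const scalar_t true_right = clamp(tokenIdx + offset_right[tokenIdx][batchIdx] * max_right,
            static_cast<scalar_t>(0.0), static_cast<scalar_t>(length - 1));

        left_idx[tokenIdx][batchIdx][0] = clamp(static_cast<int32_t>(floor(true_left)), 0, length - 1);
        left_idx[tokenIdx][batchIdx][1] = clamp(static_cast<int32_t>(ceil(true_left)), 0, length - 1);
        right_idx[tokenIdx][batchIdx][0] = clamp(static_cast<int32_t>(floor(true_right)), 0, length - 1);
        right_idx[tokenIdx][batchIdx][1] = clamp(static_cast<int32_t>(ceil(true_right)), 0, length - 1);

        alpha_left[tokenIdx][batchIdx] = left_idx[tokenIdx][batchIdx][1] - true_left;
        alpha_right[tokenIdx][batchIdx] = true_right - right_idx[tokenIdx][batchIdx][0];
    }
}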