Fast CUDA implementation for calculating cross-norm-distance of two matrices

Assume I have two matrices A and B, each of shape n * m. I want to compute an n * n matrix C such that C[i,j] = ||A[i] - B[j]||, where || || is a distance measure between two vectors, such as the infinity-norm distance (maximum absolute difference) or the Manhattan distance (sum of absolute differences). I have written a CUDA kernel using the shared-memory tiling approach described in the CUDA C++ Best Practices Guide, as follows:

#define TILE_DIM 16
// Assuming N and M are multiples of TILE_DIM for simplicity
__global__ void cross_infinity_norm(float *a, float* b, float *c, int N, int M)
{
    __shared__ float aTile[TILE_DIM][TILE_DIM], bTile[TILE_DIM][TILE_DIM + 2];
    int a_row_block = blockIdx.y * TILE_DIM + threadIdx.y;
    int b_row_block = blockIdx.x * TILE_DIM + threadIdx.x;
    float ans = 0.0f;
    for (int i = 0; i < M; i += TILE_DIM) {
         // a and b are row-major with M columns, so the row stride is M (identical to N only when N == M).
         aTile[threadIdx.y][threadIdx.x] = a[a_row_block * M + i + threadIdx.x];
         bTile[threadIdx.x][threadIdx.y] = b[(blockIdx.x * TILE_DIM + threadIdx.y) * M + i + threadIdx.x];
         __syncthreads();
         #pragma unroll
         for (int j = 0; j < TILE_DIM; j++)
             ans = max(ans, abs(aTile[threadIdx.y][j] - bTile[j][threadIdx.x]));
         __syncthreads();
     }
     c[a_row_block * N + b_row_block] = ans;
}
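
The shared-memory access pattern of the kernel above can be checked with plain host-side arithmetic. The sketch below is not part of the original program; the names kTileDim, kNumBanks, max_conflict_degree and worst_degree_for_padding are made up for illustration. It assumes 32 banks of 4-byte words and that the 256 threads of a 16 x 16 block are split into warps in the usual linear order (thread id = threadIdx.y * 16 + threadIdx.x). A constant base offset of each tile shifts all banks equally, so it is ignored here.

```cpp
#include <algorithm>
#include <set>
#include <vector>

constexpr int kTileDim = 16;   // same value as TILE_DIM in the kernel
constexpr int kNumBanks = 32;  // assumption: 32 banks, each 4 bytes wide

// For one warp-wide access, count how many DISTINCT words fall into the same
// bank (identical words are broadcast and do not conflict). Returns the worst
// case over all warps of a 16 x 16 block; 1 means conflict-free.
template <typename AddrFn>
int max_conflict_degree(AddrFn word_index) {
    int worst = 0;
    for (int warp = 0; warp < kTileDim * kTileDim / 32; ++warp) {
        std::vector<std::set<int>> words_per_bank(kNumBanks);
        for (int lane = 0; lane < 32; ++lane) {
            const int tid = warp * 32 + lane;
            const int tx = tid % kTileDim;  // threadIdx.x
            const int ty = tid / kTileDim;  // threadIdx.y
            const int w = word_index(tx, ty);
            words_per_bank[w % kNumBanks].insert(w);
        }
        for (const auto& s : words_per_bank)
            worst = std::max(worst, static_cast<int>(s.size()));
    }
    return worst;
}

// Worst conflict degree over the four shared-memory accesses in the kernel,
// for bTile declared as [kTileDim][kTileDim + pad].
int worst_degree_for_padding(int pad) {
    const int stride_b = kTileDim + pad;
    int worst = 1;
    // Store aTile[threadIdx.y][threadIdx.x]
    worst = std::max(worst, max_conflict_degree(
        [](int tx, int ty) { return ty * kTileDim + tx; }));
    // Store bTile[threadIdx.x][threadIdx.y]
    worst = std::max(worst, max_conflict_degree(
        [=](int tx, int ty) { return tx * stride_b + ty; }));
    for (int j = 0; j < kTileDim; ++j) {
        // Load aTile[threadIdx.y][j]
        worst = std::max(worst, max_conflict_degree(
            [=](int, int ty) { return ty * kTileDim + j; }));
        // Load bTile[j][threadIdx.x]
        worst = std::max(worst, max_conflict_degree(
            [=](int tx, int) { return j * stride_b + tx; }));
    }
    return worst;
}
```

Under these assumptions, worst_degree_for_padding(2) returns 1 for the TILE_DIM + 2 padding used in the kernel, while a padding of 0 or 1 would leave conflicts in the transposed store to bTile.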

I believe this kernel should be efficient, since it has no shared-memory bank conflicts and all global memory accesses are coalesced. However, when I profile it with NVIDIA Nsight Compute (with N = M = 2048), I get the following output, which includes two warnings:

----------------------------------------------------------------------
DRAM Frequency                             cycle/nsecond                   9.49
SM Frequency                               cycle/nsecond                   1.39
Elapsed Cycles                                     cycle             10,142,211
Memory [%]                                             %                  96.97
SOL DRAM                                               %                  32.92
Duration                                         msecond                   7.27
SOL L1/TEX Cache                                       %                  97.14
SOL L2 Cache                                           %                  28.21
SM Active Cycles                                   cycle          10,120,031.15
SM [%]                                                 %                  96.97
----------------------------------------------------------------------
OK    The kernel is utilizing greater than 80.0% of the available compute or memory performance of the device. To
      further improve performance, work will likely need to be shifted from the most utilized to another unit.
      Start by analyzing workloads in the Compute Workload Analysis section.

OK    The ratio of peak float (fp32) to double (fp64) performance on this device is 64:1. The kernel achieved 4% of
      this device's fp32 peak performance and 0% of its fp64 peak performance.

Section: Compute Workload Analysis
----------------------------------------------------------------------
Executed Ipc Active                           inst/cycle                   1.41
Executed Ipc Elapsed                          inst/cycle                   1.40
Issue Slots Busy                                       %                  35.17
Issued Ipc Active                             inst/cycle                   1.41
SM Busy                                                %                  35.17
----------------------------------------------------------------------

Section: Memory Workload Analysis
----------------------------------------------------------------------
Memory Throughput                           Gbyte/second                 300.02
Mem Busy                                               %                  62.33
Max Bandwidth                                          %                  96.97
L1/TEX Hit Rate                                        %                   0.27
L2 Compression Success Rate                            %                      0
L2 Compression Ratio                                                          0
L2 Hit Rate                                            %                  49.69
Mem Pipes Busy                                         %                  96.97
----------------------------------------------------------------------

Section: Scheduler Statistics
----------------------------------------------------------------------
One or More Eligible                                   %                  35.18
Issued Warp Per Scheduler                                                  0.35
No Eligible                                            %                  64.82
Active Warps Per Scheduler                          warp                  11.86
Eligible Warps Per Scheduler                        warp                   1.64
----------------------------------------------------------------------
WRN   Every scheduler is capable of issuing one instruction per cycle, but for this kernel each scheduler only
      issues an instruction every 2.8 cycles. This might leave hardware resources underutilized and may lead to
      less optimal performance. Out of the maximum of 12 warps per scheduler, this kernel allocates an average of
      11.86 active warps per scheduler, but only an average of 1.64 warps were eligible per cycle. Eligible warps
      are the subset of active warps that are ready to issue their next instruction. Every cycle with no eligible
      warp results in no instruction being issued and the issue slot remains unused. To increase the number of
      eligible warps either increase the number of active warps or reduce the time the active warps are stalled.

Section: Warp State Statistics
----------------------------------------------------------------------
Warp Cycles Per Issued Instruction                 cycle                  33.72
Warp Cycles Per Executed Instruction               cycle                  33.72
Avg. Active Threads Per Warp                                                 32
Avg. Not Predicated Off Threads Per Warp                                  31.98
----------------------------------------------------------------------
WRN   On average each warp of this kernel spends 15.9 cycles being stalled waiting for the MIO instruction queue to
      be not full. This represents about 47.2% of the total average of 33.7 cycles between issuing two
      instructions. This stall reason is high in cases of extreme utilization of the MIO pipelines, which include
      special math instructions, dynamic branches, as well as shared memory instructions.

Section: Instruction Statistics
----------------------------------------------------------------------
Avg. Executed Instructions Per Scheduler            inst           3,559,324.10
Executed Instructions                               inst          1,167,458,304
Avg. Issued Instructions Per Scheduler              inst           3,559,430.27
Issued Instructions                                 inst          1,167,493,130
----------------------------------------------------------------------

Section: Launch Statistics
----------------------------------------------------------------------
Block Size                                                                  256
Grid Size                                                                16,384
Registers Per Thread                     register/thread                     36
Shared Memory Configuration Size                   Kbyte                  32.77
Driver Shared Memory Per Block               Kbyte/block                   1.02
Dynamic Shared Memory Per Block               byte/block                      0
Static Shared Memory Per Block               Kbyte/block                   2.18
Threads                                           thread              4,194,304
Waves Per SM                                                              33.30
----------------------------------------------------------------------

Section: Occupancy
----------------------------------------------------------------------
Block Limit SM                                     block                     16
Block Limit Registers                              block                      6
Block Limit Shared Mem                             block                     32
Block Limit Warps                                  block                      6
Theoretical Active Warps per SM                     warp                     48
Theoretical Occupancy                                  %                    100
Achieved Occupancy                                     %                  98.84
Achieved Active Warps Per SM                        warp                  47.44
----------------------------------------------------------------------

(My GPU is NVIDIA RTX 3090.)
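
To relate the numbers in the report to the kernel source, a few derived quantities can be computed by hand. The sketch below only does arithmetic on values copied from the report; the struct and function names are hypothetical, and it assumes that "Executed Instructions" counts warp-level instructions for a single kernel launch.

```cpp
// Derived quantities from the Nsight Compute report (host-side arithmetic only).
struct DerivedMetrics {
    double cycles_per_issue;           // average cycles between issues, per scheduler
    double mio_stall_fraction;         // share of inter-issue cycles spent on MIO-queue stalls
    double warp_insts_per_inner_step;  // executed warp instructions per inner-loop step
};

DerivedMetrics derive_metrics() {
    const double issued_warp_per_scheduler = 0.35;  // "Issued Warp Per Scheduler"
    const double mio_stall_cycles = 15.9;           // from the second warning
    const double warp_cycles_per_issue = 33.72;     // "Warp Cycles Per Issued Instruction"
    const double executed_insts = 1167458304.0;     // "Executed Instructions"
    const double n = 2048.0, m = 2048.0;            // N = M = 2048

    // Each of the N*N threads runs M inner-loop steps; a warp holds 32 threads.
    const double warp_inner_steps = n * n * m / 32.0;

    DerivedMetrics d;
    d.cycles_per_issue = 1.0 / issued_warp_per_scheduler;
    d.mio_stall_fraction = mio_stall_cycles / warp_cycles_per_issue;
    d.warp_insts_per_inner_step = executed_insts / warp_inner_steps;
    return d;
}
```

This gives roughly 2.9 cycles per issue from the rounded 0.35 (the report's 2.8 presumably comes from the unrounded value), an MIO stall share of about 47%, and about 4.3 warp instructions per inner-loop step. At source level each inner-loop step contains two shared-memory reads; how many load instructions the compiler actually emits is not visible in this report.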

The test program is as follows:

int main() {
    int N = 2048, M = 2048;
    float *output = new float[N * N];
    float *input = new float[N * M];
    float *weight = new float[N * M];
    // Fill with pseudo-random values in [-0.5, 0.5].
    for (int i = 0; i < N * M; i++)
        input[i] = (float)rand() / RAND_MAX - 0.5f;
    for (int i = 0; i < N * M; i++)
        weight[i] = (float)rand() / RAND_MAX - 0.5f;
    float *input_cuda, *weight_cuda, *output_cuda;
    cudaMalloc(&output_cuda, N * N * sizeof(float));
    cudaMalloc(&input_cuda, N * M * sizeof(float));
    cudaMalloc(&weight_cuda, N * M * sizeof(float));
    cudaMemcpy(input_cuda, input, N * M * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(weight_cuda, weight, N * M * sizeof(float), cudaMemcpyHostToDevice);

    for (int i = 0; i < 2; i++) {
        dim3 dimBlock(TILE_DIM, TILE_DIM);
        dim3 dimGrid(N / TILE_DIM, N / TILE_DIM);
        cross_infinity_norm<<<dimGrid, dimBlock>>>(input_cuda, weight_cuda, output_cuda, N, M);
    }
    cudaDeviceSynchronize();

    cudaFree(output_cuda);
    cudaFree(input_cuda);
    cudaFree(weight_cuda);
    delete[] output;
    delete[] input;
    delete[] weight;
}
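
The test program above never reads the result back, so it does not check correctness. A minimal host-side reference for the infinity-norm case, written directly from the definition C[i,j] = ||A[i] - B[j]||, could look like the sketch below. The function name cross_infinity_norm_ref is hypothetical, and row-major storage with n rows and m columns is assumed, matching the kernel.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Host reference: c[i][j] = max over k of |a[i][k] - b[j][k]|.
// a and b are row-major n x m; the returned matrix is row-major n x n.
std::vector<float> cross_infinity_norm_ref(const std::vector<float>& a,
                                           const std::vector<float>& b,
                                           int n, int m) {
    std::vector<float> c(static_cast<std::size_t>(n) * n, 0.0f);
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            float ans = 0.0f;
            for (int k = 0; k < m; ++k) {
                const float diff = a[static_cast<std::size_t>(i) * m + k] -
                                   b[static_cast<std::size_t>(j) * m + k];
                ans = std::max(ans, std::fabs(diff));
            }
            c[static_cast<std::size_t>(i) * n + j] = ans;
        }
    }
    return c;
}
```

To use it, copy output_cuda back to output with cudaMemcpy after the kernel finishes and compare element by element. The reference is O(n * n * m), so a smaller size such as N = M = 256 is more practical for this check than 2048.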

In particular, I am confused by the profiler output: the first message suggests starting with the Compute Workload Analysis section, yet the warning above says the dominant stall reason is high in cases of extreme utilization of the MIO pipelines, which include special math instructions, dynamic branches, and shared memory instructions. My questions are:

  • How should I interpret the profiler output? In particular, what do the two warnings mean? (This program does not use special math instructions or dynamic branches.)
  • Is there a faster implementation? I really want an extremely fast implementation, since this function will be run thousands of times. Any tricks or architecture-dependent optimizations (NVIDIA RTX 3090 with Compute Capability 8.6 in my case) are welcome.