Type conversion throughput/latency

The CUDA programming guide says that the throughput of data-conversion instructions is 16 per cycle per SM. I am trying to figure out whether my calculations are correct for something like an RTX 4090 converting INT32 to FP32:

16 conversions/cycle/SM * 128 SMs * 2.2 GHz ~= 4.5 TOPS

For a tensor with 70 million elements, the conversion would then take roughly 15.5 us (70e6 / 4.5e12), ignoring any memory-related considerations.
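Something like the following trivial elementwise kernel is what I have in mind (a minimal sketch with made-up names, not my actual code):

__global__ void convert_i32_to_f32 (const int *in, float *out, int n)
{
    // plain elementwise int32 -> fp32 conversion; n would be ~70 million here
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = static_cast<float>(in[i]);  // the int -> float conversion
}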

Do the calculations above look correct at the basic level? While profiling, I noticed that the type conversions were using the XU pipeline, which is stated to be fairly slow, but the calculations above paint a different picture.

16 per cycle per SM is slow compared to the throughput of 64 per cycle per SM for basic 32-bit integer arithmetic and 128 per cycle per SM for basic 32-bit floating-point operations.

If the int32->fp32 conversion contributes to a bottleneck in your code and the conversion does not require the full range of int32 to be supported, there may be faster ways to accomplish the conversion (see example below). Alternatively, you could look into replacing the int32 variable with an fp32 variable, for example if it is a simple counter.

/* valid for -2**22 <= a < 2**22 */
__device__ float fast_int_to_float (int a)
{
    /* "magic" number 2**23 + 2**22 = 12582912.0f; for floats in [2**23, 2**24)
       one ulp equals 1, so adding 'a' to the bit pattern adds exactly 'a' to
       the represented value */
    const float fmagic = (1 << 23) + (1 << 22);
    const int imagic = __float_as_int (fmagic);
    return __int_as_float (imagic + a) - fmagic; // (fmagic + a) - fmagic == a
}
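As for the second suggestion, if the int32 value is just a loop counter feeding floating-point work, the counter itself can be kept in fp32, roughly along these lines (a hypothetical sketch, assuming the trip count stays below 2**24 so every counter value is exactly representable in fp32):

__device__ float sum_of_indices (int n)
{
    float val = 0.0f;
    // fp32 counter: exact for values below 2**24, so no int -> float
    // conversion is needed inside the loop
    for (float x = 0.0f; x < (float)n; x += 1.0f)
        val += x;
    return val;
}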

Neat! The method provided by njuffa seems to be faster in some cases.

# cat t130.cu
/* -2**22 <= a < 2**22 */
__device__ float fast_int_to_float (int a)
{
    const float fmagic = (1 << 23) + (1 << 22);
    const int imagic = __float_as_int (fmagic);
    return __int_as_float (imagic + a) - fmagic;
}

__global__ void k1(int s, int e, float *r){

  float val = 0;
  for (int i = s+threadIdx.x; i < e; i++){
    float x = i;
    val += x;}
  r[threadIdx.x] = val;
}

__global__ void k2(int s, int e, float *r){

  float val = 0;
  for (int i = s+threadIdx.x; i < e; i++){
    float x = fast_int_to_float(i);
    val += x;}
  r[threadIdx.x] = val;
}

int main(){

  const int nBLK = 58*3;
  const int nTPB = 512;
  const int s = 101;
  const int e = 1048576;
  float *r;
  cudaMalloc(&r, nTPB*sizeof(*r));
  k1<<<nBLK, nTPB>>>(s,e,r);
  k2<<<nBLK, nTPB>>>(s,e,r);
  cudaDeviceSynchronize();
  k1<<<nBLK, nTPB>>>(s,e,r);
  cudaDeviceSynchronize();
  k2<<<nBLK, nTPB>>>(s,e,r);
  cudaDeviceSynchronize();
}


# nvcc -o t130 t130.cu -arch=sm_89 -Xptxas=-v
ptxas info    : 0 bytes gmem
ptxas info    : Compiling entry function '_Z2k2iiPf' for 'sm_89'
ptxas info    : Function properties for _Z2k2iiPf
    0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 12 registers, 368 bytes cmem[0]
ptxas info    : Compiling entry function '_Z2k1iiPf' for 'sm_89'
ptxas info    : Function properties for _Z2k1iiPf
    0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 12 registers, 368 bytes cmem[0]
root@hpe-dl385-gen10-005:~/bobc# compute-sanitizer ./t130
========= COMPUTE-SANITIZER
========= ERROR SUMMARY: 0 errors
# nsys nvprof --print-gpu-trace ./t130
WARNING: t130 and any of its children processes will be profiled.

Generating '/tmp/nsys-report-9c7e.qdstrm'
[1/3] [========================100%] report26.nsys-rep
[2/3] [========================100%] report26.sqlite
[3/3] Executing 'cuda_gpu_trace' stats report

 Start (ns)   Duration (ns)  CorrId  GrdX  GrdY  GrdZ  BlkX  BlkY  BlkZ  Reg/Trd  StcSMem (MB)  DymSMem (MB)  Bytes (MB)  Throughput (MBps)  SrcMemKd  DstMemKd     Device      Ctx  Strm          Name
 -----------  -------------  ------  ----  ----  ----  ----  ----  ----  -------  ------------  ------------  ----------  -----------------  --------  --------  -------------  ---  ----  ---------------------
 728,359,084     25,519,897     119   174     1     1   512     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  k1(int, int, float *)
 753,882,117     21,606,997     120   174     1     1   512     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  k2(int, int, float *)
 775,503,834     30,668,735     122   174     1     1   512     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  k1(int, int, float *)
 806,183,065     24,234,968     124   174     1     1   512     1     1       16         0.000         0.000                                                     NVIDIA L4 (0)    1     7  k2(int, int, float *)

#

I did not do careful verification of the SASS, but I did confirm that the k1 kernel has I2FP instructions and the k2 kernel does not.

Last I checked, the fast method has throughput that is no worse than using the dedicated int32->fp32 conversion instruction across all GPU architectures, and on various architectures it performs better. Obviously one would want to verify this before deploying such code in a given context.

For int32->fp64, one can handle the full range of int32 operands in analogous fashion:

__device__ double fast_int_to_double_do_not_use (int a)
{
    const double fmagic = (1ULL << 52) + (1ULL << 32); // 0x4330000100000000ULL
    const long long int imagic = __double_as_longlong (fmagic);
    return __longlong_as_double (imagic + a) - fmagic;
}

However, this needs a 64-bit integer addition, which takes two instructions with carry propagation between them, lengthening the dependency chain. The following approach based on manual sign extension is the most efficient (the second stage of the sign extension is baked into the subtraction of the “magic constant”):

__device__ double fast_int_to_double (int a)
{
    /* a ^ 0x80000000 flips the sign bit, i.e. biases 'a' by 2**31; combined
       with high word 0x43300000 this gives the double 2**52 + (a + 2**31) */
    double t = __hiloint2double (0x43300000, 0x80000000 ^ a);
    /* subtracting 2**52 + 2**31 yields double(a) exactly, for any int32 a */
    return t - __hiloint2double (0x43300000, 0x80000000);
}

[Later:] Checking with Compiler Explorer, the code generated for sm_90 looks unusual, suggesting that the first variant of int32->fp64 conversion might perform better on that platform. On all other architectures the SASS generated for the second variant looks like I expected: LOP, MOV, DADD with constant bank operand.

Again, best to try it in any particular context where one might want to employ this.

The conversion is in the context of rescaling INT32 accumulators for an integer GEMM. The full range may be needed in some cases, but this is a good optimization to incorporate.
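For context, the rescale step would look roughly like this (a rough sketch with made-up names, using fast_int_to_float from above and assuming the accumulators stay within the -2**22 <= acc[i] < 2**22 range the trick requires):

__global__ void rescale_accumulators (const int *acc, float *out, float scale, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        // valid only while -2**22 <= acc[i] < 2**22
        out[i] = fast_int_to_float (acc[i]) * scale;
}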

Thanks for confirming my understanding of the perf impact.
