CUDA code performance

Hi all. I’m new to this forum, to CUDA, and to GPU programming.

My code does slot-wise array computation, and until now I have run it sequentially on my CPU. I’m now trying to rewrite it in CUDA and run it in parallel on the GPU, and I expect it to be quite a bit faster.

I guess I’m doing something wrong, because the code needs 56 microseconds on the GPU, while it takes around 7-8 microseconds to run sequentially.

Could someone please take a look at my code and check whether there is some stupid/obvious mistake?

I will also attach my GPU specifications, as it is quite an old FX card.

Here is my deviceQuery output:

CUDA Device Query (Runtime API) version (CUDART static linking)

Detected 1 CUDA Capable device(s)

Device 0: "Quadro FX 580"
CUDA Driver Version / Runtime Version 6.5 / 6.5
CUDA Capability Major/Minor version number: 1.1
Total amount of global memory: 511 MBytes (536150016 bytes)
( 4) Multiprocessors, ( 8) CUDA Cores/MP: 32 CUDA Cores
GPU Clock rate: 1125 MHz (1.12 GHz)
Memory Clock rate: 800 Mhz
Memory Bus Width: 128-bit
Maximum Texture Dimension Size (x,y,z) 1D=(8192), 2D=(65536, 32768), 3D=(2048, 2048, 2048)
Maximum Layered 1D Texture Size, (num) layers 1D=(8192), 512 layers
Maximum Layered 2D Texture Size, (num) layers 2D=(8192, 8192), 512 layers
Total amount of constant memory: 65536 bytes
Total amount of shared memory per block: 16384 bytes
Total number of registers available per block: 8192
Warp size: 32
Maximum number of threads per multiprocessor: 768
Maximum number of threads per block: 512
Max dimension size of a thread block (x,y,z): (512, 512, 64)
Max dimension size of a grid size (x,y,z): (65535, 65535, 1)
Maximum memory pitch: 2147483647 bytes
Texture alignment: 256 bytes
Concurrent copy and kernel execution: Yes with 1 copy engine(s)
Run time limit on kernels: Yes
Integrated GPU sharing Host Memory: No
Support host page-locked memory mapping: Yes
Alignment requirement for Surfaces: Yes
Device has ECC support: Disabled
Device supports Unified Addressing (UVA): No
Device PCI Bus ID / PCI location ID: 15 / 0
Compute Mode:
< Default (multiple host threads can use ::cudaSetDevice() with device simultaneously) >

deviceQuery, CUDA Driver = CUDART, CUDA Driver Version = 6.5, CUDA Runtime Version = 6.5, NumDevs = 1, Device0 = Quadro FX 580
Result = PASS

…and here is my code:

#include <stdio.h>
#include <stdlib.h>
#include <math.h>   // for ceil() in the grid-size computation

//current capability 1.1
//2.0 capability needed for printf (in kernel)
//1.3 capability needed for using double precision

__global__
void dvdt(int n, int ngenes, int egenes, float *v, float *vext, float *voutput, float *h, float *m, float *E, float *T, float *R, float *D, float *lambda, float *tau)
{
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  int k, k2, j, l_rule=1;
  float temp;
  
  if (i < n) {
    k = i % ngenes; 
    k2 = i % egenes;
    temp = h[k];
    temp += m[k];// * bcd.array[i];     // ap is nuclear index //Damjan: will not use this for now. We can add it later
    for( j = 0; j < egenes; j++ ) {
        temp += E[( k * egenes ) + j] * vext[(i-k2) + j]; // TODO: make each thread handle one slot of v and one (corresponding) slot of vext
    }
    for( j = 0; j < ngenes; j++ ) {
        temp += T[( k * ngenes ) + j] * v[(i-k) + j];
    }
    
    //voutput[i] = temp;
    
    //rsqrt(x) = 1/sqrt(x): using the reciprocal sqrt is more efficient than using 1 / sqrt
    voutput[i] = -lambda[k] * v[i]  +  l_rule * R[k] * 0.5 * (1 + temp * rsqrt( 1 + temp * temp )); 
        
    if( n != ngenes )     // if n == ngenes we have a single nucleus, therefore no diffusion
    {                // for multiple nuclei -> diffusion 
        if( i < ngenes ) 
        {   // first anterior-most nucleus 
            // v[i - inp->zyg.defs.ngenes] - v[i] == 0 
            voutput[i] += D[k] * ( v[i + ngenes] - v[i] );
        } 
        else if (i >= ngenes && i < n - ngenes) 
        {
            voutput[i] += D[k] * ( ( v[i - ngenes] - v[i] ) + ( v[i + ngenes] - v[i] ) );
        } 
        else 
        {   // i >= n - ngenes && i < n   // last: posterior-most nucleus 
            // v[i + inp->zyg.defs.ngenes] - v[i] == 0 
            voutput[i] += D[k] * ( v[i - ngenes] - v[i] );
        }
    }
    
  }
}

int main(void)
{
    int ngenes = 4;
    int egenes = 4;
    int nnucs = 53;
    int n = ngenes * nnucs;
    
  // Error code to check return values for CUDA calls
    cudaError_t err = cudaSuccess;

    int blockSize = 128;	

    float *v, *voutput, *d_x, *d_z, *d_y, *vext, *rh, *rm, *rE, *rT, *rR, *rD, *rlambda, *rtau; //*bot, 

    float R[4] = { 0, 0, 0, 0 };
    
    float T[16] = { -0.02875641,  0.03773355, -0.08696411,  0.02085833, 
                      0.07752513, -0.05127861,  0.08351538, -0.03693460,
                     -0.02272315, -0.05630814,  0.03840626,  0.07150631,
                      0.02276078, -0.02350780,  0.01856495, -0.06116937 
                   };        
    float E[16] = { -0.14654266,  0.01651517,  0.06442419, -0.05467585,
                     -0.12247616, -0.10550997,  0.10687539,  0.00000000, 
                      0.02896215, -0.07365022, -0.06193835,  0.00000000,
                      0.02710513, -0.04491538,  0.10493181,  0.00000000  
                   };
    float m[4] = { 0, 0, 0, 0 };
    float h[4] = { -2.5, -2.5, -2.5, -2.5 };
    float D[4] = { 0.237, 0.3, 0.115, 0.3 };
    float lambda[4] = { 13.76469068, 7.27890037, 11.63317492, 12.03105457 };
    float tau[4] = { 6, 6, 6, 6 };
    
    v = (float*)malloc(n * sizeof(float));
    vext = (float*)malloc(n * sizeof(float));
    voutput = ( float * ) malloc( n * sizeof( float ) );   
    //bot = ( double * ) calloc( n, sizeof( double ) );

    //we need parameters h, m, E, T, lambda, R, D
    //we need arrays v (main array that will be sent in parallel), v_ext

    // next loop does the rest of the equation (R, Ds and lambdas) 
        // store result in vdot[] 
  
  cudaMalloc(&d_x, n*sizeof(float)); 
  cudaMalloc(&d_y, n*sizeof(float));
  cudaMalloc(&d_z, n*sizeof(float));
  
  cudaMalloc(&rh, ngenes*sizeof(float));
  cudaMalloc(&rm, ngenes*sizeof(float));
  cudaMalloc(&rE, ngenes*egenes*sizeof(float));
  cudaMalloc(&rT, ngenes*ngenes*sizeof(float));
  cudaMalloc(&rD, ngenes*sizeof(float));  
  cudaMalloc(&rR, ngenes*sizeof(float));  
  cudaMalloc(&rlambda, ngenes*sizeof(float));
  cudaMalloc(&rtau, ngenes*sizeof(float));

  for (int i = 0; i < n; i++) {
    v[i] = i+3;
    vext[i] = i+5;
  }

  int error1 = cudaMemcpy(d_x, v, n*sizeof(float), cudaMemcpyHostToDevice);
  int error2 = cudaMemcpy(d_y, vext, n*sizeof(float), cudaMemcpyHostToDevice); 
  int error4 = cudaMemcpy(rh, h, ngenes*sizeof(float), cudaMemcpyHostToDevice);
  int error5 = cudaMemcpy(rm, m, ngenes*sizeof(float), cudaMemcpyHostToDevice);
  int error6 = cudaMemcpy(rE, E, ngenes*egenes*sizeof(float), cudaMemcpyHostToDevice);  
  int error7 = cudaMemcpy(rT, T, ngenes*ngenes*sizeof(float), cudaMemcpyHostToDevice);
  int error8 = cudaMemcpy(rR, R, ngenes*sizeof(float), cudaMemcpyHostToDevice);
  int error9 = cudaMemcpy(rD, D, ngenes*sizeof(float), cudaMemcpyHostToDevice);
  int error10 = cudaMemcpy(rlambda, lambda, ngenes*sizeof(float), cudaMemcpyHostToDevice);  
  int error11 = cudaMemcpy(rtau, tau, ngenes*sizeof(float), cudaMemcpyHostToDevice);

  if (error1 != 0) {
  	printf("error copying x to device: %d\n", error1);	
  }
  if (error2 != 0) {
  	printf("error copying y to device: %d\n", error2);	
  }
  /*if (error3 != 0) {
  	printf("error copying D to device: %d\n", error3);	
  }*/
  
  if (error4 != 0) {
  	printf("error copying h to device: %d\n", error4);	
  }
  if (error5 != 0) {
  	printf("error copying m to device: %d\n", error5);	
  }
  if (error6 != 0) {
  	printf("error copying E to device: %d\n", error6);	
  }
  if (error7 != 0) {
  	printf("error copying T to device: %d\n", error7);	
  }
  if (error8 != 0) {
  	printf("error copying R to device: %d\n", error8);	
  }
  if (error9 != 0) {
  	printf("error copying D to device: %d\n", error9);	
  }
  if (error10 != 0) {
  	printf("error copying lambda to device: %d\n", error10);	
  }
  if (error11 != 0) {
  	printf("error copying tau to device: %d\n", error11);	
  }
 
  // Launch the dvdt kernel; the grid size is rounded up so that all n slots are covered
  int gridSize = (int)ceil((float)n/blockSize);
  printf("gridsize = %d\n",gridSize);
            
  float responseTime; // cudaEventElapsedTime reports milliseconds
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start);
  
  dvdt<<<gridSize, blockSize>>>(n, ngenes, egenes, d_x, d_y, d_z, rh, rm, rE, rT, rR, rD, rlambda, rtau);

  cudaEventRecord(stop); cudaEventSynchronize(stop);
  cudaEventElapsedTime(&responseTime, start, stop); //responseTime = elapsed time
  
  printf("elapsed time = %lg us\n", (responseTime*1000)); //to get nanoseconds
  
  err = cudaGetLastError();

  if (err != cudaSuccess)
  {
      fprintf(stderr, "Failed to launch dvdt kernel (error code %s)!\n", cudaGetErrorString(err));
      exit(EXIT_FAILURE);
  }

  int error12 = cudaMemcpy(voutput, d_z, n*sizeof(float), cudaMemcpyDeviceToHost);
  if (error12 != 0) {
  	printf("error copying voutput to host: %d\n", error12);
  }

// Release device memory
    cudaFree(d_x);
    cudaFree(d_y);
    cudaFree(d_z);
    
    cudaFree(rh);
    cudaFree(rm);
    cudaFree(rE);
    cudaFree(rT);
    cudaFree(rR);
    cudaFree(rD);
    cudaFree(rlambda);
    cudaFree(rtau);
    // Release host memory
    free(v);
    free(vext);
    free(voutput);
    
    return 0;
}

edit:
I edited the code a bit to correct a few mistakes and apply some of your suggestions.

(1) The Quadro FX 580 is a very slow GPU: 72 GFLOPS single precision, 25 GB/sec memory bandwidth. It is quite possible that your CPU achieves higher performance, especially since it seems that the total data could fit into the last-level cache (meaning the effective bandwidth is much higher than 25 GB/sec). Your parallelism also seems limited: unless I made a mistake, the code uses only two thread blocks of 128 threads each. Even for a device with only 4 SMs like the Quadro FX 580 this is too little to fill the machine and cover basic latencies; you would want to target at least a thousand threads, preferably more.

(2) Your code appears to be memory bandwidth bound. It is a bit hard to see the access patterns based on cursory inspection of the code, but they don’t seem to follow a strict “base+thread_index” pattern which is necessary for best performance, especially on sm_1x devices like yours. I do not know how well the profiler works on such old hardware (the hardware support for the profiler was quite limited on sm_1x), but it may be worth a try to see whether it can point you at possible improvements.

(3) I don’t think it matters here, since the code appears memory bound, but for future reference this may be useful. Instead of

sqrtpart = sqrt( 1 + temp * temp );
(1 + temp / sqrtpart)

you would want to use

oosqrtpart = rsqrt( 1 + temp * temp );
(1 + temp * oosqrtpart)

(4) Also for future reference: Make sure to use the ‘f’ suffix on all floating-point constants used in single-precision computation. So use “0.5f” instead of “0.5”. An un-suffixed literal floating-point constant is double precision by default. This can inadvertently turn much of your computation into double-precision computation, which is often (much) slower than the desired single-precision computation. For an sm_11 target, which does not support double precision, the compiler will automatically demote to single-precision, but you should see a warning message about this demotion.
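A minimal illustration (variable names here are just for this example):

float x = 1.0f;
float a = 0.5  * x;   // 0.5 is a double literal: x is promoted, the multiply happens in double, the result is truncated back to float
float b = 0.5f * x;   // 0.5f is a float literal: the whole multiply stays in single precision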

As an aside, your code as posted has some invalid indexing going on. You can test this by running it with cuda-memcheck.

Normally, the error checking you have is enough. But in this case, your invalid indexing steps just barely outside of the legal range:

else if (i >= ngenes && i <= n - ngenes) 
{
   voutput[i] = -lambda[k] * v[i]  +  l_rule * R[k] * 0.5 * (1 + temp / sqrtpart)  +  D[k] * ( ( v[i - ngenes] - v[i] ) + ( v[i + ngenes] - v[i] ) );
}

when i = (n-ngenes), the else-if condition will test true, and its body will be executed. This will lead to requesting:

v[i + ngenes]

but when i = n-ngenes, then i+ngenes == n, and v[n] is outside of the allocated range of v (d_x). Ordinary GPU error checking (the “free” kind) will detect out-of-bounds accesses only when they exceed a certain granularity, which is device- and CUDA-version-dependent. cuda-memcheck, at the expense of reduced performance, will instrument your code and detect even these kinds of out-of-bounds accesses, which go just one element beyond the defined range.
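Running it is as simple as (executable name hypothetical):

cuda-memcheck ./dvdt_test

It will report each invalid read or write together with the offending kernel.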

You may be able to fix this just by changing the conditional test to:

else if (i >= ngenes && i < n - ngenes)

but I haven’t fully parsed your code.

I guess I'm doing something wrong, because it needed 56 nanoseconds to run the code on GPUs, and it takes around 7-8 nanoseconds to run it sequentially.

I could imagine that with such short running times, simply the initialisation of the CUDA kernel and the graphics card takes longer than the computation itself.

Why do you bother about 8 nanoseconds?

edit:
cudaEventElapsedTime returns milliseconds, so you probably get 8 microseconds vs 56 microseconds. Still, IMO a work size of 16 is too small for the GPU. Additionally, there are virtually no computations: two loops that run up to 4, and a lot of if/else (which GPUs aren’t very good at).

Bits to note.

First comment:
Your units are wrong: the timings you are measuring are microseconds, not nanoseconds.

Second comment:
Launching a kernel on the GPU takes a finite amount of time (single-digit microseconds on Linux, double-digit microseconds on Windows Vista/7/8).

Third:
Attempting a speedup on something with a runtime that small seems inefficient when you consider data locality. Since the data starts on the CPU and ends on the CPU, the time it takes to move the data across (which you don’t currently include in your timings) will outweigh any potential speedup you could make. This would only really be viable if you had to do, say, a few thousand ‘iterations’ of the kernel before pulling the data back.
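For reference, here is a minimal sketch of a timing that also covers the transfers, reusing the names from the posted code:

cudaEvent_t t0, t1;
cudaEventCreate(&t0);
cudaEventCreate(&t1);

cudaEventRecord(t0);
cudaMemcpy(d_x, v,    n*sizeof(float), cudaMemcpyHostToDevice);    // host -> device
cudaMemcpy(d_y, vext, n*sizeof(float), cudaMemcpyHostToDevice);
dvdt<<<gridSize, blockSize>>>(n, ngenes, egenes, d_x, d_y, d_z, rh, rm, rE, rT, rR, rD, rlambda, rtau);
cudaMemcpy(voutput, d_z, n*sizeof(float), cudaMemcpyDeviceToHost); // device -> host
cudaEventRecord(t1);
cudaEventSynchronize(t1);

float ms;
cudaEventElapsedTime(&ms, t0, t1);  // milliseconds, including both copies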

Fourth:
Branching is expensive. Your nested if/else blocks seem to have a lot of duplicated code, with the ‘else’ cases only adding a few extra terms. It is probably more efficient to do the common maths first, then ‘if/else’ the additional bits.
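As an untested sketch of what that could look like for the diffusion part of your kernel:

// common maths first
voutput[i] = -lambda[k] * v[i] + l_rule * R[k] * 0.5f * (1.0f + temp * rsqrt(1.0f + temp * temp));

// then only the neighbour terms are conditional
if (n != ngenes) {
    float diff = 0.0f;
    if (i >= ngenes)     diff += v[i - ngenes] - v[i];  // has an anterior neighbour
    if (i < n - ngenes)  diff += v[i + ngenes] - v[i];  // has a posterior neighbour
    voutput[i] += D[k] * diff;
}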

Good luck, Tiomat

@adamce
I just changed that. I was talking about microseconds, of course. I would not bother about a few microseconds, but this is the model that gets optimized by an optimization algorithm: it is called iteratively, each time with different parameters (the v array), until some condition is met, which may take hours or days. As this part takes around 80% of the total computational time, we thought to speed it up by parallelizing it on GPUs. We are talking about 212 values of the voutput array that can be calculated in parallel. Taking into account that each value can be calculated at least as fast as on a normal CPU, and that our GPU has 32 cores so we can calculate 32 values in parallel (at once), we thought that, if programmed correctly, it could speed up the whole algorithm considerably. And of course, if that works, we would think of investing in a GPU with more cores so we could execute all 212 calculations in parallel. But first we need to make it work on a small scale.
What do you mean by “work size of 16”?

@njuffa
Thanks for the rsqrt tip, I did not know that. Regarding your comment about float constants: in the future we will actually need double precision, but only once we buy a new GPU that supports it.

I’ll try to read a little about the GPU architecture to understand your and txbob’s comments, and come back to you.

Thank you all

Yes, that was a mistake. Thanks.

As I mentioned earlier, you will need to process many more than 212 pieces of data in parallel to utilize the full performance of the GPU. Even if you are just looking at a prototype now, I would suggest scaling it up in size to get a better understanding of what the performance on a full-size production problem would be. If you go to the upper end of the GPU performance scale, target thread counts of 20,000 or more.

In production, I will not have more than 212 threads in parallel, because the input array (212 elements) of each loop is based on the output of the previous one, and I cannot launch the next 212 threads before the first 212 threads have been computed.

So if this prototype takes as input the array vinput_loop1[212 elements] and produces voutput_loop1[212 elements], in the next execution we will have
vinput_loop2[212 elements] = f( voutput_loop1[212 elements] )

  1. Do you think that for this problem of 212 parallel threads it doesn’t make sense to go on with GPUs, as the improvement compared to CPUs will be very low, zero, or even negative?
  2. Does the number of threads per block matter, or does only the total number of threads count?

I understand that without profiling you cannot give me a precise answer, but I would appreciate any opinion.

Thank you

212 threads is quite low, so unless you can tweak other things you are unlikely (IMO) to get a speedup for this task. That said, there might be some ‘out of the box’ thinking that you can do.

  • Do you do any processing between the iterations on the CPU side? Could that be ported across so the device does a bigger ‘block’ of work?

  • Could you do multiple ‘runs’ of the entire pass in parallel? If you can set it up so you are running 16 versions of the program simultaneously (one per block, with a single block running the 212 threads) then you will start to utilise the full GPU (see the sketch after this list).

  • Following on from the multiple runs, if it is an optimization problem, would you benefit from 16 (or more) parallel runs each starting at a different point in the solution space? This seems a reasonable way to parallelize the entire program. Could you even take it a step further and have each ‘thread’ do the work that your 212 threads currently do, and launch tens of thousands of these?
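A rough sketch of the one-run-per-block idea (all names here are hypothetical; the per-slot maths would be the same as in dvdt):

__global__ void dvdt_multi(int n, float *v_all, float *vout_all /*, ...same per-run parameters... */)
{
    int run = blockIdx.x;              // which experiment this block works on
    int i   = threadIdx.x;             // which slot within that experiment
    float *v    = v_all    + run * n;  // this run's input slice
    float *vout = vout_all + run * n;  // this run's output slice
    if (i < n) {
        // ...same per-slot computation as in dvdt...
    }
}

// launch 16 independent runs, 256 threads per block (>= n = 212):
// dvdt_multi<<<16, 256>>>(n, d_v_all, d_vout_all);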

Good luck, sounds like it could be quite an interesting problem.

If I got it right, n = ngenes * nnucs = 4 * 4. So as far as I understand it, in one iteration it computes only 16 values.

Also, as Tiomat said, IMO all of that should be done on the GPU, so that you don’t transport the data back and forth over the bus and relaunch the kernel.
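A minimal sketch of what that could look like, assuming the step between iterations can be expressed as its own kernel (f_step here is hypothetical):

// d_in and d_out stay on the device across iterations; no per-iteration memcpy
for (int iter = 0; iter < maxIter; iter++) {
    dvdt<<<gridSize, blockSize>>>(n, ngenes, egenes, d_in, d_y, d_out, rh, rm, rE, rT, rR, rD, rlambda, rtau);
    f_step<<<gridSize, blockSize>>>(n, d_out, d_in);  // hypothetical: vinput_loop(k+1) = f(voutput_loop(k))
}
cudaMemcpy(voutput, d_out, n*sizeof(float), cudaMemcpyDeviceToHost); // copy back once, at the end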

But as everybody said, even a work size of 212 is very small. Maybe it’s possible to run several “experiments” in parallel on the GPU, meaning that you optimise several “models” with corresponding v arrays (or whatever parameters/data they need) in parallel (I think that’s also what Tiomat said).

nnucs = 53, so n = 212.

Following your comments I will think about two things:

  • moving the whole code to the kernel so I don’t need to move the memory back and forth
  • running more instances of the whole code in parallel, in order to get more threads

This means a lot of effort, as we are talking about moving tens of thousands of lines of code to the kernel, and I’m not sure the final result will be satisfactory.

Anyway, the important thing is that I realized that running 212 parallel threads on the GPU results in a slower computation than 212 sequential executions on the CPU, because of the latency of moving memory to the device and back.

Thank you all for opening my eyes

Ah, yes. My bad…

On a somewhat tangential note: what is a C++-portable method for declaring constants in different precisions? E.g., I want to write a kernel that is templated across precision, and want to declare constants such that the compiler automatically uses 0.5 or 0.5f depending on whether double or float is used, respectively.

I think this would work:

template <typename T>
...
const T my_float_const = (T)0.5;

// for the purists:
const T my_float_const2 = static_cast<T>(0.5);
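For a complete (if toy) example, here is a sketch of a kernel templated on precision; the kernel name and the operation are placeholders:

template <typename T>
__global__ void scaleByHalf(int n, T *x)
{
    const T half = static_cast<T>(0.5);  // becomes 0.5f for T = float, 0.5 for T = double
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        x[i] *= half;
}

// usage:
// scaleByHalf<float> <<<grid, block>>>(n, d_xf);
// scaleByHalf<double><<<grid, block>>>(n, d_xd);  // needs compute capability 1.3 or higher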