matrix multiply reduction

The 192/384 thread numbers for latency hiding are per multiprocessor, not per block.

Sorry, I have no idea what you are asking about.

Look at this line of code:

val += (im[2*nact*i+tid] * x[i]) + (im[(2*i+1)*nact+tid] * y[i]);

How many floating point ops does it contain? How many integer ops does it contain? Could you imagine a way to reduce the number of integer ops? Might it be faster if you did so?
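For concreteness, one way to cut the integer work is to carry the indices across iterations instead of recomputing them from i every time. This is a sketch only, borrowing the nact, iters and tid names from the kernel quoted later in the thread:

    float val = 0.0f;
    int index1 = tid;                        // equals 2*nact*0 + tid on entry
    for (int i = 0; i < iters; i++)
    {
        int index2 = index1 + nact;          // equals (2*i+1)*nact + tid
        val += (im[index1] * x[i]) + (im[index2] * y[i]);
        index1 += 2*nact;                    // advance to 2*nact*(i+1) + tid
    }

The two multiplies by i per iteration disappear, and the indexing drops to two integer additions.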

wow you are quick !!!

Using shared memory for x and y brought the execution time down to 0.17747 ms on the Fermi C2050. :)

I also used variables to hold the index calculations, which brought the time down to 0.16852 ms on the Fermi C2050. When I average the kernel execution over 2000 runs, I get:

“Total computation time on device = 5083.000000 usec averaged 2000 times = 2.541500 usec”

/* get_clock function */
double get_clock()
{
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return (double)tv.tv_usec + 1000000*tv.tv_sec;
}

double t4 = get_clock();
for(i = 0; i < 2000; ++i)
{
    recon_reduce<<< num_blocks, threads >>>(mtx_imd, xsh_d, ysh_d, fac_d); // num_blocks = 42, threads = 32
}
double dt4 = get_clock() - t4;
printf("Total computation time on device = %f usec averaged 2000 times = %f usec\n", dt4, dt4/2000);
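One thing to be careful of with this measurement: kernel launches are asynchronous, so the host-side clock here mostly measures how quickly the launches can be queued (plus whatever waiting the driver imposes once its launch queue fills), not the kernels' execution time. That would explain why the per-launch average comes out so much lower than the earlier cudaEvent figure. A sketch of the same loop with an explicit synchronisation before reading the clock (cudaThreadSynchronize() on pre-4.0 toolkits):

    double t4 = get_clock();
    for(i = 0; i < 2000; ++i)
    {
        recon_reduce<<< num_blocks, threads >>>(mtx_imd, xsh_d, ysh_d, fac_d);
    }
    cudaDeviceSynchronize();   // block until all 2000 queued kernels have actually finished
    double dt4 = get_clock() - t4;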

quoting the MY_GEMV example:

__global__ static void MY_GEMV(float* A, float* result)
{
    // N, padding and d_x are assumed to be defined elsewhere
    // (compile-time constants and a device-side vector, respectively).
    __shared__ float smem_tile[32][32+1];
    unsigned int t_x = threadIdx.x + __mul24(__mul24(blockIdx.x, 32), N+padding); // blockIdx.x*32*(N+64)

    float sum = 0.0f;
#pragma unroll
    for(int iters = 0; iters < N/32; iters++)
    {
#pragma unroll
        for(int i = 0; i < 32; i++)
            smem_tile[i][threadIdx.x] = A[t_x + i*(N+padding) + iters*32];
#pragma unroll
        for(int i = 0; i < 32; i++)
            sum += smem_tile[threadIdx.x][i]*d_x[i + iters*32];
    }
    result[threadIdx.x + blockIdx.x*32] = sum;
}
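A side note on the smem_tile[32][32+1] declaration in case it looks odd: Fermi shared memory has 32 banks, and the extra column pads the row stride to 33 floats, so the column-order reads in the reduction loop land in 32 different banks instead of all hitting the same one. A minimal illustration (fragment only):

    __shared__ float tile[32][32+1];        // row stride = 33 floats
    float v = tile[threadIdx.x][i];         // addresses differ by 33 across the warp,
                                            // so 32 threads hit 32 distinct banks;
                                            // with float tile[32][32] this same read
                                            // would be a 32-way bank conflict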

Sorry, but I continue to plead ignorance. My memory isn’t perfect, but that looks nothing like any code I can recall writing…

I am trying to use this kernel for different matrix sizes.

My kernel looks like this:

int num_blocks = (nac+threads-1)/threads; // nac = 1900, threads = 32
recon_reduce<<< num_blocks, threads >>>(mtxa_imd, xsha_d, ysha_d, fac_d);

const int threads = 32;
const int iters = 150; // change iters depending on matrix size
const int nact = 1900;

__global__ void recon_reduce(float* im, float* x, float* y, float* out)
{
    int tid = threadIdx.x + blockIdx.x*blockDim.x;

    float val = 0;
    int index1 = 0;
    int index2 = 0;
    __shared__ float x_val[iters];
    __shared__ float y_val[iters];

    x_val[tid] = x[tid];
    y_val[tid] = y[tid];

#pragma unroll
    for(int i = 0; i < iters; i++)
    {
        index1 = 2*nact*i+tid;
        index2 = (2*i+1)*nact+tid;
        val += (im[index1] * x_val[i]) + (im[index2] * y_val[i]);
    }
    out[tid] = val;
}

Observations:

  1. For iters = 150, the cuda_profiler gives the GPU time as 8.64 and the CPU time as 34, with occupancy 0.167. cudaEventRecord reports the total GPU time as 0.19990 ms. I also get 294 precision errors in the range (-32 to 32).

“ptxas info : Used 12 registers, 1200+0 bytes smem, 64 bytes cmem[0]”

  2. For iters = 438, the cuda_profiler gives the GPU time as 6.432 and the CPU time as 32, with occupancy 0.167. cudaEventRecord reports the total GPU time as 0.19840 ms. I also get 1687 errors; the values are completely different.

“ptxas info : Used 12 registers, 3504+0 bytes smem, 64 bytes cmem[0]”

  3. For iters = 584, the cuda_profiler gives the GPU time as 8.64 and the CPU time as 34, with occupancy 0.104. cudaEventRecord reports the total GPU time as 0.19728 ms. I also get 1687 errors; the values are completely different.

“ptxas info : Used 12 registers, 4672+0 bytes smem, 64 bytes cmem[0]”

  4. For iters = 1168, the cuda_profiler gives the GPU time as 8.672 and the CPU time as 41, with occupancy 0.104. cudaEventRecord reports the total GPU time as 0.28381 ms. I also get 1687 errors; the values are completely different.

“ptxas info : Used 12 registers, 9344+0 bytes smem, 64 bytes cmem[0]”

  5. For iters = 1750, the cuda_profiler gives the GPU time as 2794.02 and the CPU time as 2821, with occupancy 0.062. cudaEventRecord reports the total GPU time as 2.98550 ms. I also get 1735 errors; the values are completely different.

“ptxas info : Used 12 registers, 14000+0 bytes smem, 64 bytes cmem[0]”

Questions: (1) How do I rectify the precision errors?

(2) When the number of iterations increases, the kernel produces incorrect results. How do I remove these errors?

(3) The code is structured such that the shared memory usage depends on the number of iterations, which may be a poor choice. Any suggestions for using less shared memory for large iteration counts such as 584 and above?

Thanks in advance :)
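One more thing that may be worth checking, separate from the shared-memory question: with nac = 1900 and threads = 32, the rounded-up grid launches 60 blocks, i.e. 1920 threads, so the last 20 threads write past the end of out (and read slightly past the end of im) if those arrays hold exactly nact rows. A sketch of a guard, assuming that sizing:

    __global__ void recon_reduce(float* im, float* x, float* y, float* out)
    {
        int tid = threadIdx.x + blockIdx.x*blockDim.x;
        if (tid >= nact)            // surplus threads from the rounded-up grid
            return;                 // do nothing
        /* ... rest of the kernel unchanged ... */
    }

If the kernel ever gains a __syncthreads(), the guard should wrap just the loads and the final store instead of returning early, so that every thread still reaches the barrier.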

Suggestions?

For iters = 438, I tried using 8 threads and the kernel breaks at blockIdx.x = 57

Program received signal CUDA_EXCEPTION_5, Warp Out-of-range Address.
[Switching to CUDA Kernel 0 (<<<(57,0),(0,0,0)>>>)]
0x0000000013ef1cf8 in recon_reduce<<<(238,1),(8,1,1)>>> (im=0xfc00340000, x=0xfc009a0000, y=0xfc009a0800, out=0xfc02426000)
    at recon_mtx_tests_kernel.cu:33

Seriously?

Look at this code:

int tid = threadIdx.x + blockIdx.x*blockDim.x;
....
__shared__ float x_val[iters];
__shared__ float y_val[iters];

x_val[tid] = x[tid];
y_val[tid] = y[tid];

Want to hazard a guess why an out-of-bounds error might occur on x_val or y_val when tid >= iters?

Also, why are you persisting with these nonsensical block sizes? The warp size in CUDA is and has always been 32. You need blocks with a round multiple of 32 threads per block to have even a hope of reaching respectable performance levels…
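To make the point concrete, here is one way the per-block load could be written so that it stays inside the arrays no matter how many blocks are launched: every thread of the block cooperates to fill the whole tile using threadIdx.x (not the global tid), and a barrier keeps readers out until the load is complete. A sketch only, keeping the original x_val[iters]/y_val[iters] sizing:

    __shared__ float x_val[iters];
    __shared__ float y_val[iters];

    // each block needs ALL of x and y, so let the whole block load them cooperatively
    for (int j = threadIdx.x; j < iters; j += blockDim.x)
    {
        x_val[j] = x[j];
        y_val[j] = y[j];
    }
    __syncthreads();   // make the loads visible to every thread before the dot-product loop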

@avidday: Question: what would be the best way to initialize the shared memory in this example? Should x_val and y_val be declared with iters+threads elements, so that x_val[i] inside the loop does not go out of bounds?

Using this approach:

__global__ void recon_reduce(float* im, float* x, float* y, float* out)
{
    int tid = threadIdx.x + blockIdx.x*blockDim.x;

    float val = 0;
    int index1 = 0;
    int index2 = 0;
    // __shared__ float x_val[iters+threads];
    // __shared__ float y_val[iters+threads];
    float x_val;
    float y_val;
    // x_val[tid] = x[tid];
    // y_val[tid] = y[tid];

#pragma unroll
    for(int i = 0; i < iters; i++)
    {
        x_val = x[i];
        y_val = y[i];
        index1 = 2*nact*i+tid;
        index2 = (2*i+1)*nact+tid;
        val += (im[index1] * x_val) + (im[index2] * y_val);
    }
    out[tid] = val;
}

Sorry I don’t understand what you are trying to ask.

As the code in post #14 is written (leaving aside the indexing errors I pointed out above), it can only ever work correctly when the number of threads per block == iters. If it is not obvious why, then I suggest you go and re-read the bits of CUDA documentation that cover the execution model and shared memory.

Here is some food for thought:

__global__ void recon_reduce
(const float* im, const float* x, const float* y, float* out)
{
    int tid = threadIdx.x + blockIdx.x*blockDim.x;
    float val = 0;
    int index1 = tid, index2 = tid + iters * nact;
    __shared__ float x_val[32];
    __shared__ float y_val[32];

    for(int pos = 0; pos < iters; pos += 32) {
        if (threadIdx.x < 32) {
            x_val[threadIdx.x] = x[pos + threadIdx.x];
            y_val[threadIdx.x] = y[pos + threadIdx.x];
        }
        __syncthreads();
        for(int i = 0; (i < 32); i++) {
            val += im[index1] * x_val[i];
            val += im[index2] * y_val[i];
            index1 += nact;
            index2 += nact;
        }
    }
    out[tid] = val;
}

which when combined with this:

#!/usr/bin/env python
from pycuda import driver, compiler, gpuarray, tools
import pycuda.autoinit
import numpy as np

im = np.asarray(np.random.rand(1600,320),dtype=np.float32,order="F")
x = np.random.rand(160).astype(np.float32)
y = np.random.rand(160).astype(np.float32)

blocksz = (64,1,1)
gridsz = (25,1)

im_ = gpuarray.to_gpu(im)
x_ = gpuarray.to_gpu(x)
y_ = gpuarray.to_gpu(y)
out_ = gpuarray.empty((1600),dtype=np.float32)

vmod = driver.module_from_file("vivek.cubin")
vkernel = vmod.get_function("_Z12recon_reducePKfS0_S0_Pf")
vtime = vkernel(im_, x_, y_, out_, block=blocksz, grid=gridsz, time_kernel=True)

out = out_.get()
refout = np.dot(im, np.hstack([x,y]))
abserr = np.abs(out-refout)
relerr = abserr / np.maximum(1e-10, np.abs(refout))

print "Solution time = ", vtime
print "Maximum absolute error = ", abserr.max()
print "Maximum relative error = ", relerr.max()

does this:

avidday@cuda:~$ nvcc --cubin -arch=sm_20 -Xptxas="-v" vivek.cu
ptxas info    : Compiling entry function '_Z12recon_reducePKfS0_S0_Pf' for 'sm_20'
ptxas info    : Used 18 registers, 256+0 bytes smem, 64 bytes cmem[0]
avidday@cuda:~$ python vivek.py
Solution time =  9.08374786377e-05
Maximum absolute error =  9.15527e-05
Maximum relative error =  1.13266e-06
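One small caveat about the kernel above, offered as an observation rather than a tested fix: with 64 threads per block only the first warp refills x_val/y_val, and nothing in the loop prevents it from starting the next refill while the second warp is still reading the previous tile. The test above happens to produce correct results, but the ordering is not guaranteed; a second barrier at the end of the pos loop closes that window:

    for(int pos = 0; pos < iters; pos += 32) {
        if (threadIdx.x < 32) {
            x_val[threadIdx.x] = x[pos + threadIdx.x];
            y_val[threadIdx.x] = y[pos + threadIdx.x];
        }
        __syncthreads();        // tile fully loaded before anyone reads it
        for(int i = 0; i < 32; i++) {
            val += im[index1] * x_val[i];
            val += im[index2] * y_val[i];
            index1 += nact;
            index2 += nact;
        }
        __syncthreads();        // everyone finished reading before the next refill
    }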
