Finding a suitable cuBLAS function, and half-space swap (fftshift) algorithm strategy discussion

Hi everyone,

I’m looking for a cuBLAS function that performs (double complex) element-wise vector multiplication. I have searched for an answer, and people say there is no such function in cuBLAS. However, someone in this Stack Overflow thread (c++ - Element-wise vector-vector multiplication in BLAS?) suggested using the sbmv function and treating one of the vectors as a diagonal matrix. But cuBLAS only provides the real-valued sbmv variants (Ssbmv and Dsbmv, no Zsbmv), so I cannot use that trick.
I wrote a naive custom kernel to perform this operation, but its runtime didn’t satisfy me. For example, my kernel took twice as long as the Zdotc() kernel (which not only does the element-wise multiplication but also sums the elements afterwards). To be honest, it greatly discourages me from using my own kernel! :))
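For reference, here is roughly how that sbmv trick would look for plain real double vectors. This is only a sketch based on my reading of the cublasDsbmv documentation (the wrapper name and parameter choices are my own guesses), and since there is no Zsbmv it does not cover my double-complex case:

#include <cublas_v2.h>

// Sketch only: y[i] = diag[i] * x[i] for n real doubles, by treating d_diag as a
// banded matrix with bandwidth k = 0 (i.e. just its diagonal, stored with lda = 1).
// Assumes the default CUBLAS_POINTER_MODE_HOST for alpha/beta.
void ewMultiRealViaSbmv(cublasHandle_t handle, int n,
                        const double *d_diag, const double *d_x, double *d_y)
{
	const double alpha = 1.0, beta = 0.0;
	cublasDsbmv(handle, CUBLAS_FILL_MODE_LOWER, n, 0,
	            &alpha, d_diag, 1, d_x, 1, &beta, d_y, 1);
}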

On the other hand, can anyone suggest a better algorithm for swapping half-spaces along each dimension of a matrix? I’m writing an FFTShift function in CUDA that behaves like MATLAB’s fftshift(), which “swaps half-spaces of matrix X along each dimension”. To simplify the task, I’ve narrowed the scope to swapping half-spaces of a 3D matrix with an even number of elements along each dimension (NX×NY×NZ with NX%2 = NY%2 = NZ%2 = 0). My strategy is to execute 3 kernels consecutively, cuFFTShiftZ() → cuFFTShiftY() → cuFFTShiftX(), each performing the task along one dimension. Here are my kernels:

// ShiftZ simply swaps the elements of two vectors with each other (same as cublasZswap)
__global__ void cuFFTShiftZ(cuDoubleComplex *a, cuDoubleComplex *b){
	int index = blockIdx.x*blockDim.x + threadIdx.x;
	cuDoubleComplex temp;
	temp = a[index];
	a[index] = b[index];
	b[index] = temp;
}
// ShiftY locates the elements belonging to the first half of dimension Y and swaps them
__global__ void cuFFTShiftY(cuDoubleComplex *a){
	int index = blockIdx.x*blockDim.x + threadIdx.x;
	int total = NX*NY;
	int yindex = index % total;
	cuDoubleComplex temp;
	if (yindex < (total / 2)){
		temp = a[index];
		a[index] = a[index + (total / 2)];
		a[index + (total / 2)] = temp;
	}
}
// ShiftX locates the elements belonging to the first half of dimension X and swaps them
__global__ void cuFFTShiftX(cuDoubleComplex *a){
	int index = blockIdx.x*blockDim.x + threadIdx.x;
	int xindex = index % NX;
	cuDoubleComplex temp;
	if (xindex < (NX / 2)){
		temp = a[index];
		a[index] = a[index + (NX / 2)];
		a[index + (NX / 2)] = temp;
	}
}

And I execute them as follows:

// d_a is the input matrix, stored with NX as the leading dimension
	// NX, NY, NZ in this example are 30, 64, 64
	dim3 numThreads = 128;
	dim3 numBlocks = NX*NY*NZ / numThreads;
	cuFFTShiftZ <<<numBlocks / 2, numThreads >>>(d_a, d_a + (NX*NY*NZ / 2)); // half-length of the array
	cuFFTShiftY <<<numBlocks, numThreads >>>(d_a); // perform on full length array
	cuFFTShiftX <<<numBlocks, numThreads >>>(d_a); // perform on full length array

It works. But there are 3 things that greatly concern me:

  1. Why, for such a simple task as interchanging the elements of 2 vectors, does my kernel take twice as long as cublasZswap()? (I used the Visual Profiler to compare them.) How can the cuBLAS kernel be so fast?!
  2. The ShiftY and ShiftX kernels waste too many resources: half of the threads do nothing! Can anyone show me what I can do to improve this?
  3. With this approach, I have to read and write the array in global memory 3 times! That is a waste of bandwidth and highly redundant.

I’ve tried another approach: instead of letting one thread update the values of 2 elements, each thread updates the value of only one element. Here is my alternate kernel:

__global__ void cuFFTShiftX(cuDoubleComplex *a){
	int tid = blockIdx.x*blockDim.x*blockDim.y + threadIdx.y*blockDim.x + threadIdx.x;
	int tidB = threadIdx.y*blockDim.x + threadIdx.x;
	__shared__ cuDoubleComplex sdata[960];
	sdata[tidB] = a[tid];
	__syncthreads();
	if (threadIdx.x < (NX / 2)){
		a[tid] = sdata[tidB + (NX / 2)];
	}
	else{
		a[tid] = sdata[tidB - (NX / 2)];
	}
}

And the parameters for the kernel launch:

dim3 numT(30, 32); // NX = 30
cuFFTShiftX <<<128, numT >>>(d_a);

However, it runs even slower than the resource-wasting version! I suspect it could be because of 3 things:

  1. The call to __syncthreads()
  2. Bank conflicts
  3. Each warp now only has 30 threads running instead of 32 -> I tried ``` dim3 numT(32,30); ``` and changed the code above to adapt to it, but the result is still the same.

Thanks everyone so much for helping me out with this! :)

Regarding your elementwise multiplication, you should be able to write a kernel that is as fast as the cublas dot kernel. Since you haven’t provided your code, I’m not sure what else can be said.

The shift operation you are trying to perform is a data movement situation. Your goal, as you have already surmised, is to write a kernel that reads each item only once, and writes each item only once, and performs the move operation. This means you will need to figure out the final (input-to-final-output) indexing in a single operation, rather than in steps or pieces. Furthermore for optimality, all your reads should be properly coalesced reads and all your writes should be properly coalesced writes. At that point, there isn’t any equivalent CUBLAS operation that could/should be faster.

You might want to study the matrix transpose blog here:
https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/
to get some ideas. I don’t have a ready-made code to suggest for the exact input-to-output indexing for this kind of shift operation.
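That said, as a very rough, untested sketch of the kind of single-pass pairing I mean (assuming even NX, NY, NZ, x varying fastest as in your code; the kernel name, the int parameters, and the one-thread-per-pair choice are my own):

// Untested sketch: single-pass 3D fftshift for even NX, NY, NZ with x varying fastest.
// Each thread owns one element in the lower half of z and swaps it with its shifted
// partner, so every element is read exactly once and written exactly once.
__global__ void fftShift3D(cuDoubleComplex *a, int NX, int NY, int NZ)
{
	int tid = blockIdx.x*blockDim.x + threadIdx.x;
	int pairs = NX*NY*(NZ/2);              // one thread per swap pair
	if (tid >= pairs) return;

	int x = tid % NX;
	int y = (tid / NX) % NY;
	int z = tid / (NX*NY);                 // z < NZ/2 by construction

	int xs = (x + NX/2) % NX;              // partner: shift by half along each dimension
	int ys = (y + NY/2) % NY;
	int zs = z + NZ/2;                     // no wrap needed since z < NZ/2

	int src = (z *NY + y )*NX + x;
	int dst = (zs*NY + ys)*NX + xs;

	cuDoubleComplex tmp = a[src];
	a[src] = a[dst];
	a[dst] = tmp;
}

Launched with enough threads to cover the NX*NY*NZ/2 pairs (e.g. <<<(NX*NY*NZ/2 + 127)/128, 128>>>), it touches the whole array once; within each row the reads and writes split into two contiguous halves, so coalescing is reasonable though not perfect.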


Hi txbob,

You can run this sample code and observe the results in the Visual Profiler:

#include "cuda_runtime.h"
#include "device_launch_parameters.h"

#include <stdio.h>
#include <cublas_v2.h>
#include <math.h>
#include <stdlib.h>
#include <helper_functions.h>
#include <helper_cuda.h>

#define errChk if (cudaStatus != cudaSuccess) {fprintf(stderr, "CUDA failed: %s\n", cudaGetErrorString(cudaStatus)); goto Error;};
#define BlaserrChk if (blasstat != CUBLAS_STATUS_SUCCESS) {fprintf(stderr, "cuBLAS failed: %s \n", _cudaGetErrorEnum(blasstat)); goto Error;};

__global__ void ewMulti(cuDoubleComplex *a, cuDoubleComplex *b){
	int tid = blockIdx.x*blockDim.x + threadIdx.x;
	double tmp = a[tid].x;
	a[tid].x = a[tid].x*b[tid].x - a[tid].y*b[tid].y;
	a[tid].y = tmp*b[tid].y + a[tid].y*b[tid].x;
}

__global__ void SwapKernel(cuDoubleComplex *a, cuDoubleComplex *b){
	int index = blockIdx.x*blockDim.x + threadIdx.x;
	cuDoubleComplex temp;
	temp = a[index];
	a[index] = b[index];
	b[index] = temp;
}

int main()
{
	cuDoubleComplex *h_x = new cuDoubleComplex[16384];
	cuDoubleComplex *h_xx = new cuDoubleComplex[16384];
	for (int i = 0; i < 16384; i++){
		h_x[i].x = rand() % 10; h_x[i].y = rand() % 10;
		h_xx[i].x = rand() % 10; h_xx[i].y = rand() % 10;
	}
	cuDoubleComplex h_y;

	cuDoubleComplex *d_x;
	cuDoubleComplex *d_xx;
	cuDoubleComplex *d_y;

	cudaError_t cudaStatus;
	cublasStatus_t blasstat;
	cublasHandle_t handle;

	int numThreads = 128;
	int numBlocks = 16384 / numThreads;

	// cuBLAS handler
	blasstat = cublasCreate(&handle); BlaserrChk;
	blasstat = cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); BlaserrChk;

	// Choose which GPU to run on, change this on a multi-GPU system.
	cudaStatus = cudaSetDevice(0); errChk;

	// Allocate GPU buffers for vectors
	cudaStatus = cudaMalloc((void**)&d_x, 16384 * sizeof(cuDoubleComplex)); errChk;
	cudaStatus = cudaMalloc((void**)&d_xx, 16384 * sizeof(cuDoubleComplex)); errChk;
	cudaStatus = cudaMalloc((void**)&d_y, sizeof(cuDoubleComplex)); errChk;
	
	// Copy input vectors from host memory to GPU buffers.
	cudaStatus = cudaMemcpy(d_x, h_x, 16384 * sizeof(cuDoubleComplex), cudaMemcpyHostToDevice); errChk;
	cudaStatus = cudaMemcpy(d_xx, h_xx, 16384 * sizeof(cuDoubleComplex), cudaMemcpyHostToDevice); errChk;
	
	blasstat = cublasZdotc(handle, 16384, d_x, 1, d_xx, 1, d_y); BlaserrChk;
	
	ewMulti <<<numBlocks, numThreads>>>(d_x, d_xx);
	cudaStatus = cudaGetLastError(); errChk;

	SwapKernel<<<numBlocks, numThreads>>>(d_x, d_xx);
	cudaStatus = cudaGetLastError(); errChk;

	blasstat = cublasZswap(handle, 16384, d_x, 1, d_xx, 1); BlaserrChk;

	// Copy output vector from GPU buffer to host memory.
	cudaStatus = cudaMemcpy(&h_y, d_y, sizeof(cuDoubleComplex), cudaMemcpyDeviceToHost); errChk;
	cudaStatus = cudaMemcpy(h_x, d_x, 16384 * sizeof(cuDoubleComplex), cudaMemcpyDeviceToHost); errChk;
	printf("y = x'*x = %f %fi\n",
		h_y.x, h_y.y);

	cudaStatus = cudaDeviceSynchronize(); errChk;

Error:
	cudaFree(d_x);
	cudaFree(d_xx);
	cudaFree(d_y);
	delete[] h_x;
	delete[] h_xx;
	blasstat = cublasDestroy(handle);
	// cudaDeviceReset must be called before exiting in order for profiling and
	// tracing tools such as Nsight and Visual Profiler to show complete traces.
	cudaStatus = cudaDeviceReset();
	if ((cudaStatus != cudaSuccess) || (blasstat != CUBLAS_STATUS_SUCCESS))
		return 1;
	return 0;
}

I get 33.7 us for my ewMulti kernel and a total of 13.5 us for the 2 kernels (dot and reduce) of cublasZdotc().
Similar results for the swapping kernels (25 us for my custom kernel vs 9.8 us for the cuBLAS kernel). Tested on a Quadro K4200.

I can’t think of any faster way to perform this operation, but still, cublasZswap() beats my kernel hands down! :))
About the shift function, someone has already done this before me. His approach is, as you said, to compute directly where each element should end up. However, his method “operates only on 1D arrays of even sizes and 2D/3D arrays with power-of-two sizes and unified dimensionality”, which means the applicable 3D matrix must be a cube. If I want to use his library, I would have to pad my data to meet those requirements, which could inflate my data several-fold. Besides, there is one more thing that greatly concerns me. Here is his kernel: https://github.com/marwan-abdellah/cufftShift/blob/master/Src/CUDA/Kernels/out-of-place/cufftShift_3D_OP.cu
With three nested if statements, I suppose his kernel would cause a lot of warp divergence. Could you take a look and verify my speculation?!
I will take your suggestions into consideration when writing the next version of this kernel. Actually, this shift kernel was built in order to satisfy the condition:

And thanks for the link; it will be very helpful for my learning!
Oh, and there are a few other things I would like to ask:

  1. Do I have to call cudaStreamDestroy() before cudaDeviceReset()? If I place it after, it throws an error about a memory access violation, but I failed to find a warning about this in any of the provided NVIDIA documents.
  2. If I initialize the thread parameters for a kernel like this: ``` dim3 numThreads(30, 32) ``` will my warps now consist of only 30 threads, or are they still groups of 32 threads? --> Section 4.1, SIMT Architecture, in the CUDA C Programming Guide
  3. My data type is double complex, and my Quadro K4200 has 64 FP64 cores and 192 FP32 cores per SM. Should I choose the number of threads in proportion to the number of FP64 cores, the number of FP32 cores, or something else? --> Answer here: https://devtalk.nvidia.com/default/topic/897696/relationship-between-threads-and-gpu-core-units/

Do this instead:

__global__ void ewMulti(cuDoubleComplex * __restrict__ a, const cuDoubleComplex * __restrict__ b){
	int tid = blockIdx.x*blockDim.x + threadIdx.x;
	cuDoubleComplex my_a = a[tid];
	cuDoubleComplex my_b = b[tid];
	cuDoubleComplex res;
	res.x = my_a.x*my_b.x - my_a.y*my_b.y;
	res.y = my_a.x*my_b.y + my_a.y*my_b.x;
	a[tid] = res;
}

and make sure you are building a release project not a debug project.

This is memory-bound code, so the critical factor is to manage your memory loads and stores properly. (There are no perfectly coalesced reads or writes anywhere in your ewMulti kernel, and the profiler can tell you this: look at the gld_efficiency and gst_efficiency metrics.) The Zdot kernel actually has an advantage over this case, because it need not do any global memory stores for the entire result array, so it’s not quite a fair comparison. But this code should get you closer than where you are now.
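For example, from the command line (the executable name here is just a placeholder):

nvprof --metrics gld_efficiency,gst_efficiency ./your_app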

Regarding your last 3 questions:

  1. I don’t see you using streams anywhere. If you want to call cudaStreamDestroy, you should do it before calling cudaDeviceReset. cudaDeviceReset wipes the slate clean as far as the GPU is concerned, so all allocations and anything previously created is destroyed by that operation. If you then attempt to delete or destroy something, you will get an error. From the documentation:

https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1gef69dd5c6d0206c2b8d099abac61f217

“Destroy all allocations and reset all state on the current device in the current process…Explicitly destroys and cleans up all resources associated with the current device in the current process. Any subsequent API call to this device will reinitialize the device.”

That seems fairly self-explanatory. Not sure why you would say you failed to find a warning about this in any of the provided NVIDIA documents. I guess the only thing clearer might be a long laundry list of what is destroyed, but since it is everything, that seems unnecessary to me. And if you destroy and clean up all resources, then it stands to reason that if you subsequently try to destroy one of those resources again, you might get an error.
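In code terms, a minimal sketch of the ordering (error checking omitted):

cudaStream_t stream;
cudaStreamCreate(&stream);
// ... launch work into stream ...
cudaStreamDestroy(stream);   // destroy it while the device state still exists
cudaDeviceReset();           // after this point the old stream handle is invalid; destroying it now would return an error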

  2. Warps will contain 32 threads. You will have 30 of them. The first warp will have 30 threads whose threadIdx.y value will be 0 and whose threadIdx.x value will increment from 0…29. After that there will be two more threads in the warp whose threadIdx.y value is 1 and whose threadIdx.x values are 0 and 1. This is covered in the documentation:

https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#simt-architecture
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#thread-hierarchy

as well as other places on the web, for example:
https://stackoverflow.com/questions/15044671/how-is-the-2d-thread-blocks-padded-for-warp-scheduling/15044884#15044884
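If you want to see the mapping for yourself, here is a tiny sketch (my own example, not from your code) that prints the first thread of each warp:

#include <cstdio>

// For blockDim = (30, 32): the in-block linear id is threadIdx.y*blockDim.x + threadIdx.x,
// and warps are formed from 32 consecutive linear ids.
__global__ void showWarps()
{
	int linear = threadIdx.y*blockDim.x + threadIdx.x;
	if (linear % 32 == 0)
		printf("warp %d begins at threadIdx = (%d, %d)\n", linear/32, threadIdx.x, threadIdx.y);
}

// launch: showWarps<<<1, dim3(30, 32)>>>();
// warp 1 begins at threadIdx = (2, 1), confirming that warp 0 spans (0..29, y=0) plus (0..1, y=1)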

  3. This is one of the most commonly asked questions in CUDA programming. It is explained in probably dozens of places on the web. Choose the number of threads either to fit your problem size or to saturate your GPU. Neither of these methods has any direct connection to things like the number of cores or any other type of execution unit inside an SM. If you study any of the literally dozens of CUDA codes provided by NVIDIA, you will find none of them choosing threads or blocks based on CUDA cores, whether FP32 or FP64. The most common thread strategy is to choose one thread for each output (or, alternatively, input) point in your data set. This is an example of choosing threads to fit the problem size, and it works well for a great many types of algorithms. A method which allows you to choose the threads based on GPU specifics (i.e. to saturate the GPU) is the grid-stride loop, which is explained here:

https://devblogs.nvidia.com/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/

A grid-stride loop that chooses at least 2048 * (number of SMs) as its total thread count (and then loops over the data set with those threads) will also work well for a great many problems.
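For instance, your ewMulti kernel written with a grid-stride loop might look like this (a sketch; the extra n parameter is my addition):

__global__ void ewMulti(cuDoubleComplex * __restrict__ a, const cuDoubleComplex * __restrict__ b, int n)
{
	// grid-stride loop: the grid can be sized to saturate the GPU (e.g. 2048 * number
	// of SMs total threads) while each thread strides through the data until n is covered
	for (int i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += gridDim.x*blockDim.x){
		cuDoubleComplex my_a = a[i];
		cuDoubleComplex my_b = b[i];
		cuDoubleComplex res;
		res.x = my_a.x*my_b.x - my_a.y*my_b.y;
		res.y = my_a.x*my_b.y + my_a.y*my_b.x;
		a[i] = res;
	}
}

You could then launch it with something like <<<16 * number_of_SMs, 128>>> regardless of the vector length.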

For more detailed, general treatments, read the first two answers here:

https://stackoverflow.com/questions/9985912/how-do-i-choose-grid-and-block-dimensions-for-cuda-kernels


OMG! OMG! OMG! I’ve learned a ton from your example and instructions! Seeing how you manipulate the code made me rethink everything in my understanding of coalesced reads and writes. I had read about it before, but didn’t fully or correctly understand how to “make it work”. All I knew was that it is best to arrange data next to each other and perform reads and writes on them; but the way I access the data is also a big factor affecting my throughput. Reading those guides and instructions again now makes much more sense to me.
Thanks a lot, txbob! You are a great help!!

P.S.: somehow, just by switching from debug to release mode, my SwapKernel now performs better than the cuBLAS Zswap kernel. The way the compiler handles/optimizes my code for the release build is really mysterious! :))

Update: you updated your answer while I was following exactly the same route to find an answer myself! :)
Alright, about Q1: I assumed that entities like the cuBLAS handle and streams live on the host side, so cudaDeviceReset() would not touch them. Besides, I couldn’t find anything about that problem because I was focusing only on the literature about streams and didn’t read the definition of cudaDeviceReset carefully either. My bad!
Q2 & Q3 popped up late last night while I was preparing/testing my demo code for you, so I didn’t have time to search for the answers myself. My apologies for these unnecessary duplicated questions!