Using cublasCgemmStridedBatched for FFT convolution: what settings for the multiplication?

I am attempting to do FFT convolution using cuFFT and cuBLAS.

I have everything working up to the element-wise multiplication and channel-sum step. Rather than hand-roll that step, I believe it would be faster to use cublasCgemmStridedBatched, since the per-bin multiply and sum over input channels is exactly a small matrix product (sketched below).
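For context, this is roughly what my working element-wise + sum step does. A minimal sketch, assuming one padded padH*padW spectrum per slice, with images ordered [batchSize][inC][bin] and kernels [inC][outC][bin] (the slice orderings are my convention, not fixed by cuFFT):

    // Sketch of the hand-written step: for each (image, output channel, bin),
    // multiply image and kernel spectra elementwise and sum over input channels.
    __global__ void mulSumSpectra(const cuComplex* __restrict__ imgs, // [batch][inC][bin]
                                  const cuComplex* __restrict__ kers, // [inC][outC][bin]
                                  cuComplex* __restrict__ out,        // [batch][outC][bin]
                                  int batchSize, int inC, int outC, int bins)
    {
        int idx = blockIdx.x * blockDim.x + threadIdx.x; // one thread per (b, o, bin)
        if (idx >= batchSize * outC * bins) return;
        int bin = idx % bins;
        int o   = (idx / bins) % outC;
        int b   = idx / (bins * outC);
        cuComplex acc = {0.0f, 0.0f};
        for (int c = 0; c < inC; ++c)
            acc = cuCaddf(acc, cuCmulf(imgs[(b * inC + c) * bins + bin],
                                       kers[(c * outC + o) * bins + bin]));
        out[idx] = acc;
    }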

I wish to multiply matrices AB = C. I am aware that cublasCgemmStridedBatched works in column-major order, so row-major inputs are effectively read as their transposes: passing A and B computes

A^T B^T = (BA)^T

which is not what I want. To get around this without transposing the result, I swap the operands and pass B, A instead, since

B^T A^T = (AB)^T = C^T

and C^T stored in column-major order is exactly C read back in row-major order.

This changes the values that must be given to cublasCgemmStridedBatched.
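For a single non-batched product this is the standard row-major trick; a minimal sketch with placeholder names (d_A, d_B, d_C, m, n, k for a row-major m x k times k x n product, not taken from my code below):

    // Row-major C = A*B via column-major cuBLAS by swapping the operands:
    // cuBLAS reads the row-major buffers as B^T (n x k) and A^T (k x m) and
    // computes C^T = B^T * A^T, which read back row-major is C (m x n).
    cublasCgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                n, m, k,
                &alpha,
                d_B, n,   // ld = row width of row-major B
                d_A, k,   // ld = row width of row-major A
                &beta,
                d_C, n);  // ld = row width of row-major C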

My A matrix is the batched FFT of some images; the B matrix is a batch of FFT'ed kernels. I have successfully computed this in Python with

A (batched FFT'ed images) of shape [imgH*imgW, batchSize, inC]
B (batched FFT'ed kernels) of shape [imgH*imgW, inC, outC]

using NumPy's matmul as np.matmul(A, B).
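Semantically, that np.matmul does one small (batchSize x inC) by (inC x outC) complex product per spectral bin. A CPU sketch of the result I expect, assuming the same row-major [bin, batchSize, inC] / [bin, inC, outC] layout as the Python arrays (matmulPerBin is an illustrative helper; it needs cuComplex.h):

    // CPU reference matching np.matmul(A, B): one small complex GEMM per
    // spectral bin p, accumulating over input channels c.
    void matmulPerBin(const cuComplex *A, const cuComplex *B, cuComplex *C,
                      int bins, int batchSize, int inC, int outC)
    {
        for (int p = 0; p < bins; ++p)
            for (int b = 0; b < batchSize; ++b)
                for (int o = 0; o < outC; ++o)
                {
                    cuComplex acc = {0.0f, 0.0f};
                    for (int c = 0; c < inC; ++c)
                        acc = cuCaddf(acc, cuCmulf(A[(p * batchSize + b) * inC + c],
                                                   B[(p * inC + c) * outC + o]));
                    C[(p * batchSize + b) * outC + o] = acc;
                }
    }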

When I do this with cublasCgemmStridedBatched I get incorrect results. My settings can be seen below.

My thoughts on what the inputs need to be:

Since I am passing in B A, cuBLAS reads these as B^T A^T. This leaves me with an m x k by k x n product. I believe my batch dimension is going to be imgH*imgW. m is then outC, k is inC, and n is batchSize. Because the data is row-major, the leading dimension of each operand is its row-major column count, so ldA will be outC, ldB will be inC, and ldC will be outC.

This leaves the strides. I believe these follow 3-D indexing, where the stride should jump from one matrix in the batch to the next, so I set each stride to the size of one 2-D matrix: outC*inC for A, inC*batchSize for B, and batchSize*outC for C.
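The addressing I based those strides on is the documented column-major form: element (row m, col n) of batch p lives at ptr[p*stride + n*ld + m]. A small helper restating that mapping (illustrative only, not a verified fix):

    // Column-major, strided-batched index of element (m, n) in batch p.
    static inline size_t idxColMajor(int m, int n, int p, int ld, size_t stride)
    {
        return (size_t)p * stride + (size_t)n * ld + (size_t)m;
    }
    // e.g. kerPtr  as outC x inC per bin:       idxColMajor(m, k, p, outC, outC*inC)
    // e.g. imgsPtr as inC  x batchSize per bin: idxColMajor(k, n, p, inC,  inC*batchSize)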

    // Runs but gives incorrect results
    cublasStat = cublasCgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                                           outC, batchSize, inC,
                                           &alpha,
                                           kerPtr, outC, inC*outC,
                                           imgsPtr, inC, inC*batchSize,
                                           &beta,
                                           d_C, outC, batchSize*outC,
                                           padH*padW);

What are the proper arguments to feed into cublasCgemmStridedBatched to get the correct results?

Here's a reproducible example:

    #include <assert.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>
    
    // Include CUDA runtime and CUFFT
    #include <cuda_runtime.h>
    #include <cufft.h>
    #include <cuComplex.h>
    #include <math.h>
    #include "cufft.h"
    #include "cublas_v2.h"
    
    __global__
    void Pad2DArrayNoLoop(cuComplex* __restrict__ img, cuComplex* __restrict__ padded, int imgH, int imgW, int padH, int padW)
    {
    	int i = blockIdx.x * blockDim.x + threadIdx.x;
    	int j = blockIdx.y * blockDim.y + threadIdx.y;
    
    	if( i >= imgH || j >= imgW)	// || so any out-of-range thread exits
    	{
    		return;
    	}
    
    	padded[i*padW + j].x = img[i*imgW + j].x;
    	padded[i*padW + j].y = img[i*imgW + j].y;
    }
    
    __global__
    void Pad3DArrayNoLoop(cuComplex* __restrict__ img, cuComplex* __restrict__ padded, int batchSize, int imgH, int imgW, int padH, int padW)
    {
    	int i = blockIdx.x * blockDim.x + threadIdx.x;
    	int j = blockIdx.y * blockDim.y + threadIdx.y;
    	int k = blockIdx.z * blockDim.z + threadIdx.z;
    
    	if( i >= batchSize || j >= imgH || k >= imgW)	// || so any out-of-range thread exits
    	{
    		return;
    	}
    
    	padded[i*padH*padW + j*padW + k].x = img[i*imgW*imgH + j*imgW + k].x;
    	padded[i*padH*padW + j*padW + k].y = img[i*imgW*imgH + j*imgW + k].y;
    }
    
    void initArrZero(cuComplex *arr, int arrH, int arrW, int batchSize=1)
    {
    	for(int i = 0; i < arrH * arrW * batchSize; i++)
    	{
    		arr[i].x = 0;
    		arr[i].y = 0;
    	}
    }
    
    void initArr_I(cuComplex *arr, int arrH, int arrW)
    {
    	for(int i = 0; i < arrH * arrW; i++)
    	{
    		arr[i].x = i;
    		arr[i].y = 0;
    	}
    }
    
    void initArr_Ibatch(cuComplex *arr, int arrH, int arrW, int batchSize, int addVal = 0)
    {
    	for(int i = 0; i < batchSize * arrH * arrW; i++)
    	{
    		arr[i].x = i + addVal;
    		arr[i].y = 0;
    	}
    }
    
    void printArr(cuComplex * arr, int arrH, int arrW)
    {
    	for(int i = 0; i < arrH; i++)
    	{
    		for(int j = 0; j < arrW; j++)
    		{
    			printf("(%f, %f) ", arr[i*arrW + j].x, arr[i*arrW + j].y);
    		}
    		printf("\n");
    	}
    }
    
    void printArr3D(cuComplex * arr, int arrH, int arrW, int batchSize)
    {
    	for(int i = 0; i < batchSize; i++)
    	{
    		for(int j = 0; j < arrH; j++)
    		{
    			for(int k = 0; k < arrW; k++)
    			{
    				printf("(%f,%f) ", arr[i*arrH*arrW + j*arrW + k].x, arr[i*arrH*arrW + j*arrW + k].y);
    			}
    			printf("\n");
    		}
    		printf("\n");
    	}
    }
    
    
    void FFT_Conv_Batch(int batchSize, int imgH, int imgW, int kerH, int kerW, int inC, int outC)
    {
    	cuComplex *h_imgs;
    	cuComplex *d_imgs;
    	cuComplex *h_ker;
    	cuComplex *d_ker;
    	cuComplex *h_padded;
    	cuComplex *d_padded;
    	cuComplex *h_kerPadded;
    	cuComplex *d_kerPadded;
    	cuComplex *h_C;
    	cuComplex *d_C;
    
    	cufftHandle fftPlanFwdImg, fftPlanFwdKernel, fftPlanInv;
    	cufftComplex *d_ImgsSpectral, *d_KernelSpectral;
    
    	int padH = imgH + kerH - 1;
    	int padW = imgW + kerW - 1;
    
    	static const int nRank = 2;
    	// For no Padding
    	int n[nRank] = {padH,padW};
    
    	cufftResult_t stat;
    	printf("Creating cuFFT plans\n");
    	stat = cufftPlanMany(& fftPlanFwdImg,nRank,n,NULL,1,0,NULL,1,0,CUFFT_C2C,batchSize*inC);
    
    	if( stat != CUFFT_SUCCESS )
    	{
    		printf("Cufft plan didn't init\n");
    	}
    
    	stat = cufftPlanMany(& fftPlanFwdKernel,nRank,n,NULL,1,0,NULL,1,0,CUFFT_C2C,inC*outC);
    
    	if( stat != CUFFT_SUCCESS )
    	{
    		printf("Cufft plan didn't init\n");
    	}
    
    	stat = cufftPlanMany(& fftPlanInv, nRank, n, NULL,1,0,NULL,1,0,CUFFT_C2C,batchSize*outC);
    
    	if( stat != CUFFT_SUCCESS )
    	{
    		printf("Cufft plan didn't init\n");
    	}
    
    	printf("Allocating Memory on CPU\n");
    	h_imgs = (cuComplex *) malloc(batchSize * inC * imgH * imgW * sizeof(cuComplex));
    	h_padded = (cuComplex *) malloc(batchSize * inC * padH * padW * sizeof(cuComplex));
    	h_ker = (cuComplex *) malloc(inC * outC * kerH * kerW * sizeof(cuComplex));
    	h_kerPadded = (cuComplex *) malloc(inC * outC * padH * padW * sizeof(cuComplex));
    	h_C = (cuComplex *) malloc(batchSize * outC * padH * padW * sizeof(cuComplex));
    
    	printf("Allocating Memory on GPU\n");
    	cudaMalloc((void **) &d_imgs, batchSize * inC * imgH * imgW * sizeof(cuComplex));
    	cudaMalloc((void **) &d_padded, batchSize * inC * padH * padW * sizeof(cuComplex));
    	cudaMalloc((void **) &d_ker, inC * outC * kerH * kerW * sizeof(cuComplex));
    	cudaMalloc((void **) &d_kerPadded, inC * outC * padH * padW * sizeof(cuComplex));
    	cudaMalloc((void **) &d_ImgsSpectral, batchSize * inC * padH * padW * sizeof(cuComplex));
    	cudaMalloc((void **) &d_KernelSpectral, inC * outC * padH * padW * sizeof(cuComplex));
    	cudaMalloc((void **) &d_C, batchSize * outC * padH * padW * sizeof(cuComplex));
    
    	// Offset the kernel values so they differ from the image values
    	int kerAddVal = 2;
    	initArr_Ibatch(h_imgs,imgH,imgW,batchSize*inC);
    	initArr_Ibatch(h_ker,kerH,kerW,inC*outC,kerAddVal);
    	initArrZero(h_padded,padH,padW,batchSize*inC);
    	initArrZero(h_kerPadded,padH,padW,inC*outC);
    	initArrZero(h_C,padH,padW,batchSize*outC);
    
    	// To test with complex input
    //	h_imgs[3].y=2.0f;
    
    	cudaMemcpyAsync(d_imgs,h_imgs, batchSize * inC * imgH * imgW * sizeof(cuComplex),cudaMemcpyHostToDevice);
    	cudaMemcpyAsync(d_ker,h_ker,inC * outC * kerH * kerW * sizeof(cuComplex), cudaMemcpyHostToDevice);
    	cudaMemcpyAsync(d_padded,h_padded,batchSize * inC * padH * padW * sizeof(cuComplex),cudaMemcpyHostToDevice);
    	cudaMemcpyAsync(d_kerPadded, h_kerPadded, inC * outC * padH * padW * sizeof(cuComplex), cudaMemcpyHostToDevice);
    	cudaMemcpyAsync(d_C,h_C,batchSize * outC * padH * padW * sizeof(cuComplex), cudaMemcpyHostToDevice);
    
    //	printf("Img\n");
    //	printArr3D(h_imgs,imgH,imgW,batchSize*inC);
    //	printf("Kers\n");
    //	printArr3D(h_ker,kerH,kerW,inC*outC);
    //	printf("Padded Empty Arr\n");
    //	printArr3D(h_padded,padH,padW,batchSize*inC);
    
    	dim3 blockSizes(batchSize*inC,4,4);
    	dim3 numBlocks( (batchSize*inC + blockSizes.x - 1)/ blockSizes.x, (imgH + blockSizes.y - 1)/ blockSizes.y, (imgW + blockSizes.z - 1)/ blockSizes.z);
    	Pad3DArrayNoLoop<<<numBlocks,blockSizes>>>(d_imgs, d_padded, batchSize, imgH, imgW, padH, padW);
    
    	blockSizes.x = inC*outC; blockSizes.y = 3; blockSizes.z = 3;
    	numBlocks.x = (inC*outC + blockSizes.x - 1) / blockSizes.x;
    	numBlocks.y = (kerH + blockSizes.y - 1) / blockSizes.y;
    	numBlocks.z = (kerW + blockSizes.z - 1) / blockSizes.z;
    	// The batch count for the kernel stack is inC*outC slices, not batchSize
    	Pad3DArrayNoLoop<<<numBlocks,blockSizes>>>(d_ker,d_kerPadded,inC*outC,kerH,kerW,padH,padW);
    
    	cudaMemcpyAsync(h_padded,d_padded,batchSize*inC*padH*padW*sizeof(cuComplex),cudaMemcpyDeviceToHost);
    	cudaMemcpyAsync(h_kerPadded,d_kerPadded,inC*outC*padH*padW*sizeof(cuComplex),cudaMemcpyDeviceToHost);
    
    //	printf("Copied to padding\n");
    //	printf("Images\n");
    //	printArr3D(h_padded,padH,padW,batchSize*inC);
    //	printf("Kernels\n");
    //	printArr3D(h_kerPadded,padH,padW,inC*outC);
    
    	printf("Executing FFT\n");
    	stat = cufftExecC2C(fftPlanFwdImg, (cufftComplex *) d_padded, (cufftComplex *) d_ImgsSpectral,CUFFT_FORWARD);
    
    	if(stat != CUFFT_SUCCESS)
    	{
    		printf("FFT execution error.\n");
    	}
    
    	stat = cufftExecC2C(fftPlanFwdKernel, (cufftComplex *) d_kerPadded, (cufftComplex *) d_KernelSpectral,CUFFT_FORWARD);
    
    	if(stat != CUFFT_SUCCESS)
    	{
    		printf("Kernel FFT Execution error.\n");
    	}
    
    	cudaMemcpyAsync(h_padded,d_ImgsSpectral,batchSize*inC*padH*padW*sizeof(cuComplex),cudaMemcpyDeviceToHost);
    	cudaMemcpyAsync(h_kerPadded,d_KernelSpectral,inC*outC*padH*padW*sizeof(cuComplex),cudaMemcpyDeviceToHost);
    
    	printf("FFT Image Results\n");
    	printArr3D(h_padded,padH,padW,batchSize*inC);
    //
    	printf("FFT Kernel Results\n");
    	printArr3D(h_kerPadded,padH,padW,inC*outC);
    
    	printf("Element wise Multiplication\n");
    
    	cublasStatus_t cublasStat;
    	cublasHandle_t handle;
    	cuComplex alpha = {1.0f, 0.0f};
    	cuComplex beta = {0.0f, 0.0f};
    	// Doesn't matter, the matrices are square
    //	int dim = padH;
    
    //	CPU reference for what cublasCgemmStridedBatched should compute,
    //	in column-major terms: C[p] = Ker[p] * Imgs[p] for each spectral bin p.
    //	int h_w = padH * padW;
    //	int strideA = outC*inC;
    //	int strideB = batchSize*inC;
    //	int strideC = outC*batchSize;
    //	int ldA = outC;
    //	int ldB = inC;
    //	int ldC = outC;
    //
    //	for (int p = 0; p < h_w; ++p)
    //	{
    //		for (int m = 0; m < outC; ++m)
    //		{
    //			for (int n = 0; n < batchSize; ++n)
    //			{
    //				cuComplex c_mnp = {0.0f, 0.0f};
    //				for (int k = 0; k < inC; ++k)
    //				{
    //					cuComplex a = h_kerPadded[m + k*ldA + p*strideA];
    //					cuComplex b = h_padded[k + n*ldB + p*strideB];
    //					// full complex multiply-accumulate
    //					c_mnp.x += a.x*b.x - a.y*b.y;
    //					c_mnp.y += a.x*b.y + a.y*b.x;
    //				}
    //				h_C[m + n*ldC + p*strideC] = c_mnp;
    //			}
    //		}
    //	}
    
    
    
    
    	// Stride has to be the size of the 2D matrix, same stuff as 3D indexing
    //	int strideA = batchSize*inC, strideB = inC*outC, strideC = padH*padW;
    
    	cublasStat = cublasCreate(&handle);
    
    	if(cublasStat != CUBLAS_STATUS_SUCCESS)
    	{
    		printf("Cublas Error\n");
    	}
    
    	// Do here to get updated memory locations
    	cuComplex const *imgsPtr = d_ImgsSpectral;
    	cuComplex const *kerPtr = d_KernelSpectral;
    //
    //	cublasStat = cublasCgemmStridedBatched(handle,CUBLAS_OP_N,CUBLAS_OP_N,
    //													  padH, padW, padH,
    //													  & alpha,
    //													  kerPtr, padW, padH*padW,
    //													  imgsPtr, padW, padH*padW,
    //													  & beta,
    //													  d_C, padW, padH*padW,
    //													  h_w*outC);
    
    	// Runs but gives incorrect results
    	cublasStat = cublasCgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
    	                                       outC, batchSize, inC,
    	                                       &alpha,
    	                                       kerPtr, outC, inC*outC,
    	                                       imgsPtr, inC, inC*batchSize,
    	                                       &beta,
    	                                       d_C, outC, batchSize*outC,
    	                                       padH*padW);
    
    	if(cublasStat != CUBLAS_STATUS_SUCCESS)
    	{
    		printf("Cublas Error\n");
    	}
    
    	cudaMemcpy(h_C,d_C,batchSize * outC * padH * padW * sizeof(cuComplex),cudaMemcpyDeviceToHost);
    	printf("Cublas Batch Res \n");
    	printArr3D(h_C, padH, padW, batchSize*outC);
    
    	printf("Freeing Memory\n");
    	cudaFree(d_imgs);
    	cudaFree(d_ker);
    	cudaFree(d_padded);
    	cudaFree(d_kerPadded);
    	free(h_imgs);
    	free(h_ker);
    	free(h_padded);
    	free(h_kerPadded);
    	cudaDeviceSynchronize();
    }
    
    int main()
    {
    	int batchSize = 1;
    	int inC = 3;
    	int outC = 1;
    	int imgH = 4;
    	int imgW = 4;
    	int kerH = 3;
    	int kerW = 3;
    
    	FFT_Conv_Batch(batchSize, imgH, imgW, kerH, kerW, inC, outC);
    }