One more addition to this thread: below is a MEX program for calling a generalized SGEMM. It can be called from MATLAB as C=sgemm_cu(transa,transb,alpha,beta,A,B,C), and it computes $C = \alpha\,\mathrm{op}(A)\,\mathrm{op}(B) + \beta C$, where $\mathrm{op}(A)$ and $\mathrm{op}(B)$ are A and B, each transposed or not according to transa and transb. I've chosen to use integer flags for these: 0 = no transpose, 1 (or any other value) = transpose.
At the moment, this MEX file gives me more than a 20% improvement over ordinary MATLAB in my particular application, which uses matrices of roughly 3000×3000. I need to go to larger dimensions soon, however, so I expect this MEX file will be well worthwhile, provided single precision is sufficient for the calculation.
I'm posting this in the hope that at least some people will find it useful.
#include <limits.h>

#include "mex.h"
#include "cublas.h"
/* Copy the first Ntot elements of input_double into output_float, narrowing
   each one with a (float) cast. Nothing is written when Ntot <= 0.
   NOTE(review): doubles outside float's range are not handled specially. */
void convert_double2float( double *input_double, float *output_float,int Ntot)
{
    double *src = input_double;
    float *dst = output_float;
    int remaining;

    for (remaining = Ntot; remaining > 0; remaining--) {
        *dst = (float) *src;
        dst++;
        src++;
    }
}
/* Copy the first Ntot elements of input_float into output_double, widening
   each one with a (double) cast. Nothing is written when Ntot <= 0. */
void convert_float2double( float *input_float, double *output_double,int Ntot)
{
    int k = 0;

    while (k < Ntot) {
        output_double[k] = (double) input_float[k];
        ++k;
    }
}
/* sgemm_cu.cu - Gateway function for subroutine sgemm
   function [C]=sgemm_cu(transa,transb,alpha,beta,A,B,C)
   transa,transb = 0/1 for no transpose/transpose of A,B
   (any flag that is not 0 after truncation to int means transpose)

   Returns alpha*op(A)*op(B) + beta*C computed in single precision with the
   legacy CUBLAS API, where op(X) is X or X.' according to the flag.
   A, B and C must be real, full double matrices; op(A) must be MM x KK,
   op(B) must be KK x NN and C must be MM x NN. The result is returned as a
   new MM x NN double matrix; the C argument itself is not modified.
   All failures are reported with mexErrMsgTxt.
*/

/* Nonzero when a rows x cols matrix has dimensions and an element count
   that fit in an int, the size type taken by the convert_* helpers and by
   the legacy CUBLAS calls in mexFunction. */
static int sgemm_cu_fits_int(size_t rows, size_t cols)
{
    if (rows > (size_t) INT_MAX || cols > (size_t) INT_MAX) {
        return 0;
    }
    return cols == 0 || rows <= (size_t) INT_MAX / cols;
}

void mexFunction( int nlhs, mxArray *plhs[],
int nrhs, const mxArray *prhs[])
{
    /* All locals are declared up front: this file is built as C++ by nvcc,
       where a goto may not jump over an initialised declaration. */
    cublasStatus status;
    int M, K, L, N;        /* A is M x K, B is L x N */
    int MM, NN, KK, KB;    /* op(A) is MM x KK, op(B) is KB x NN */
    int ta, tb, i, total;
    int cublas_ready = 0;
    float alpha, beta;
    double *A, *B, *C, *CC;
    float *a, *b, *c;
    float *ga = NULL, *gb = NULL, *gc = NULL;
    char transa, transb;
    const char *err = NULL;

    if (nrhs != 7) {
        mexErrMsgTxt("sgemm requires 7 input arguments");
    } else if (nlhs > 1) {
        /* nlhs == 0 is allowed: MATLAB then assigns plhs[0] to ans. */
        mexErrMsgTxt("sgemm returns only 1 output argument");
    }

    /* mxGetPr is only meaningful for real, full double arrays, and every
       element count handed to the converters/CUBLAS must fit in an int. */
    for (i = 4; i < 7; i++) {
        if (!mxIsDouble(prhs[i]) || mxIsComplex(prhs[i]) || mxIsSparse(prhs[i])) {
            mexErrMsgTxt("sgemm: A, B and C must be real, full double matrices");
        }
        if (!sgemm_cu_fits_int(mxGetM(prhs[i]), mxGetN(prhs[i]))) {
            mexErrMsgTxt("sgemm: matrix has too many elements");
        }
    }

    ta = (int) mxGetScalar(prhs[0]);
    tb = (int) mxGetScalar(prhs[1]);
    alpha = (float) mxGetScalar(prhs[2]);
    beta = (float) mxGetScalar(prhs[3]);
    M = (int) mxGetM(prhs[4]); /* rows of A */
    K = (int) mxGetN(prhs[4]); /* columns of A */
    L = (int) mxGetM(prhs[5]); /* rows of B */
    N = (int) mxGetN(prhs[5]); /* columns of B */

    if (ta == 0) {
        transa = 'n';
        MM = M;
        KK = K;
    } else {
        transa = 't';
        MM = K;
        KK = M;
    }
    if (tb == 0) {
        transb = 'n';
        KB = L;
        NN = N;
    } else {
        transb = 't';
        KB = N;
        NN = L;
    }

    /* Without these checks CUBLAS and the C conversion read past the
       input arrays. */
    if (KB != KK) {
        mexErrMsgTxt("sgemm: inner dimensions of op(A) and op(B) must agree");
    }
    if (mxGetM(prhs[6]) != (size_t) MM || mxGetN(prhs[6]) != (size_t) NN) {
        mexErrMsgTxt("sgemm: C must be size(op(A),1) x size(op(B),2)");
    }

    /* Left hand side: MM x NN real double matrix (mxCreateDoubleMatrix
       takes mwSize dimensions, unlike the int array used before). */
    plhs[0] = mxCreateDoubleMatrix(MM, NN, mxREAL);
    CC = mxGetPr(plhs[0]);
    if (MM == 0 || NN == 0) {
        return; /* empty result: nothing to compute */
    }
    total = MM * NN; /* equals numel(C), already checked to fit in an int */

    A = mxGetPr(prhs[4]);
    B = mxGetPr(prhs[5]);
    C = mxGetPr(prhs[6]);

    /* Single-precision host copies. The casts on mxMalloc are kept because
       this file is compiled as C++, where void* does not convert implicitly. */
    c = (float*) mxMalloc(sizeof(float) * (size_t) total);
    convert_double2float(C, c, total);

    if (KK == 0) {
        /* op(A)*op(B) is an MM x NN zero matrix. Zero-length CUBLAS buffers
           and leading dimensions of 0 are invalid, so apply beta on the
           host; as in BLAS, C's contents are ignored when beta == 0. */
        for (i = 0; i < total; i++) {
            c[i] = (beta == 0.0f) ? 0.0f : beta * c[i];
        }
        convert_float2double(c, CC, total);
        mxFree(c);
        return;
    }

    /* Here MM, NN, KK > 0, so M, K, L, N are all > 0 as well. */
    a = (float*) mxMalloc(sizeof(float) * (size_t) M * (size_t) K);
    b = (float*) mxMalloc(sizeof(float) * (size_t) L * (size_t) N);
    convert_double2float(A, a, M * K);
    convert_double2float(B, b, L * N);

    /* CUBLAS is started and shut down on every call. mexErrMsgTxt does not
       return, so every failure below jumps to cleanup first to release GPU
       memory, and the error is raised afterwards. */
    status = cublasInit();
    if (status != CUBLAS_STATUS_SUCCESS) {
        err = "CUBLAS: an error occurred in cublasInit";
        goto cleanup;
    }
    cublas_ready = 1;

    /* Allocate space on the GPU and copy a, b, c into it. */
    if (cublasAlloc(M * K, sizeof(float), (void**)&ga) != CUBLAS_STATUS_SUCCESS) {
        ga = NULL;
        err = "CUBLAS: an error occurred in cublasAlloc";
        goto cleanup;
    }
    if (cublasSetMatrix(M, K, sizeof(float), a, M, (void*)ga, M) != CUBLAS_STATUS_SUCCESS) {
        err = "CUBLAS: an error occurred in cublasSetMatrix";
        goto cleanup;
    }
    if (cublasAlloc(L * N, sizeof(float), (void**)&gb) != CUBLAS_STATUS_SUCCESS) {
        gb = NULL;
        err = "CUBLAS: an error occurred in cublasAlloc";
        goto cleanup;
    }
    if (cublasSetMatrix(L, N, sizeof(float), b, L, (void*)gb, L) != CUBLAS_STATUS_SUCCESS) {
        err = "CUBLAS: an error occurred in cublasSetMatrix";
        goto cleanup;
    }
    if (cublasAlloc(total, sizeof(float), (void**)&gc) != CUBLAS_STATUS_SUCCESS) {
        gc = NULL;
        err = "CUBLAS: an error occurred in cublasAlloc";
        goto cleanup;
    }
    if (cublasSetMatrix(MM, NN, sizeof(float), c, MM, (void*)gc, MM) != CUBLAS_STATUS_SUCCESS) {
        err = "CUBLAS: an error occurred in cublasSetMatrix";
        goto cleanup;
    }

    /* Leading dimensions are the stored row counts of A, B and C. */
    cublasSgemm(transa, transb, MM, NN, KK, alpha, ga, M, gb, L, beta, gc, MM);
    if (cublasGetError() != CUBLAS_STATUS_SUCCESS) {
        err = "CUBLAS: kernel execution error in cublasSgemm";
        goto cleanup;
    }

    /* Copy the result gc on the GPU back to the host buffer c. */
    if (cublasGetMatrix(MM, NN, sizeof(float), gc, MM, c, MM) != CUBLAS_STATUS_SUCCESS) {
        err = "CUBLAS: an error occurred in cublasGetMatrix";
        goto cleanup;
    }
    convert_float2double(c, CC, total);

cleanup:
    if (gc != NULL) {
        cublasFree(gc);
    }
    if (gb != NULL) {
        cublasFree(gb);
    }
    if (ga != NULL) {
        cublasFree(ga);
    }
    if (cublas_ready) {
        cublasShutdown();
    }
    mxFree(a);
    mxFree(b);
    mxFree(c);
    if (err != NULL) {
        mexErrMsgTxt(err);
    }
}