/*
 * Benchmark driver for sgemm() (defined in "sgemm.cu").
 *
 * Builds two N x N single-precision matrices on the host (A[i] = B[i] = i),
 * runs sgemm(N, N, N, 1.0f, A, B, 1.0f, C) on the device with C initially
 * zero, times the call with CUDA events, copies the results back, prints the
 * matrices when N <= 32, and reports the elapsed time and GFLOP/s.
 *
 * NOTE(review): the original header names were lost (only bare "#include"
 * tokens survived); the four below are reconstructed from what this file
 * uses. <mpi.h> is no longer needed by main() (timing now uses CUDA events)
 * but is kept in case "sgemm.cu" relies on it -- TODO confirm and drop.
 */
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include <mpi.h>
#include "sgemm.cu"

void fill(unsigned int n, unsigned int m, float *input);
void zero(unsigned int n, unsigned int m, float *input);
void printMatrix(unsigned int n, unsigned int m, float *input);

/* Abort with a diagnostic if a CUDA runtime call failed; `what` names the call. */
static void checkCudaCall(cudaError_t err, const char *what)
{
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(err));
        exit(EXIT_FAILURE);
    }
}

/* malloc() for a host matrix buffer that aborts instead of returning NULL. */
static float *allocHostMatrix(size_t bytes)
{
    float *p = (float *)malloc(bytes);
    if (p == NULL) {
        fprintf(stderr, "malloc(%zu) failed\n", bytes);
        exit(EXIT_FAILURE);
    }
    return p;
}

int main(int argc, char **argv)
{
    const unsigned int N = 1024;
    /* size_t: in 32-bit unsigned arithmetic N * N * sizeof(float) wraps for N >= 32768. */
    const size_t size = (size_t)N * N * sizeof(float);
    /* Operation count of an N x N by N x N product: N^2 output entries,
     * each needing N multiplies and N - 1 adds. */
    const double aN = (double)N * (double)N * (2.0 * (double)N - 1.0);
    float *d_a, *h_a, *d_b, *h_b, *d_c, *h_c;
    cudaEvent_t start, stop;
    float elapsedMs = 0.0f;
    double seconds;

    (void)argc; /* command-line arguments are currently unused */
    (void)argv;

    checkCudaCall(cudaMalloc((void **)&d_a, size), "cudaMalloc(d_a)");
    checkCudaCall(cudaMalloc((void **)&d_b, size), "cudaMalloc(d_b)");
    checkCudaCall(cudaMalloc((void **)&d_c, size), "cudaMalloc(d_c)");
    h_a = allocHostMatrix(size);
    h_b = allocHostMatrix(size);
    h_c = allocHostMatrix(size);

    fill(N, N, h_a);
    fill(N, N, h_b);
    zero(N, N, h_c);

    checkCudaCall(cudaMemcpy(d_a, h_a, size, cudaMemcpyHostToDevice), "cudaMemcpy(d_a <- h_a)");
    checkCudaCall(cudaMemcpy(d_b, h_b, size, cudaMemcpyHostToDevice), "cudaMemcpy(d_b <- h_b)");
    checkCudaCall(cudaMemcpy(d_c, h_c, size, cudaMemcpyHostToDevice), "cudaMemcpy(d_c <- h_c)");

    checkCudaCall(cudaEventCreate(&start), "cudaEventCreate(start)");
    checkCudaCall(cudaEventCreate(&stop), "cudaEventCreate(stop)");

    /* Measure on the device. sgemm() presumably launches its kernel(s)
     * asynchronously, so reading a host clock right after the call would
     * measure only launch overhead. Synchronizing the device before recording
     * `stop` makes the interval cover all work sgemm() queued, whichever
     * stream it used, and surfaces any in-kernel fault here. */
    checkCudaCall(cudaEventRecord(start, 0), "cudaEventRecord(start)");
    sgemm(N, N, N, 1.0f, d_a, d_b, 1.0f, d_c);
    checkCudaCall(cudaGetLastError(), "sgemm launch");
    checkCudaCall(cudaDeviceSynchronize(), "sgemm execution");
    checkCudaCall(cudaEventRecord(stop, 0), "cudaEventRecord(stop)");
    checkCudaCall(cudaEventSynchronize(stop), "cudaEventSynchronize(stop)");
    checkCudaCall(cudaEventElapsedTime(&elapsedMs, start, stop), "cudaEventElapsedTime");
    seconds = (double)elapsedMs / 1000.0;

    checkCudaCall(cudaMemcpy(h_a, d_a, size, cudaMemcpyDeviceToHost), "cudaMemcpy(h_a <- d_a)");
    checkCudaCall(cudaMemcpy(h_b, d_b, size, cudaMemcpyDeviceToHost), "cudaMemcpy(h_b <- d_b)");
    checkCudaCall(cudaMemcpy(h_c, d_c, size, cudaMemcpyDeviceToHost), "cudaMemcpy(h_c <- d_c)");

    if (N <= 32) {
        printMatrix(N, N, h_a);
        printMatrix(N, N, h_b);
        printMatrix(N, N, h_c);
    }

    checkCudaCall(cudaEventDestroy(start), "cudaEventDestroy(start)");
    checkCudaCall(cudaEventDestroy(stop), "cudaEventDestroy(stop)");
    checkCudaCall(cudaFree(d_a), "cudaFree(d_a)");
    checkCudaCall(cudaFree(d_b), "cudaFree(d_b)");
    checkCudaCall(cudaFree(d_c), "cudaFree(d_c)");
    free(h_a);
    free(h_b);
    free(h_c);

    if (seconds > 0.0) {
        printf("Time = %f s, Gflops = %f\n", seconds, aN / (seconds * 1.0e9));
    } else {
        /* Avoid dividing by a zero elapsed time. */
        printf("Time = %f s, Gflops = n/a\n", seconds);
    }
    return 0;
}

/* Set input[k] = k for every element of an n x m row-major buffer. */
void fill(unsigned int n, unsigned int m, float *input)
{
    const size_t count = (size_t)n * m; /* size_t: n * m may exceed 32 bits */
    size_t i;
    for (i = 0; i < count; i++) {
        input[i] = (float)i;
    }
}

/* Set every element of an n x m buffer to 0. */
void zero(unsigned int n, unsigned int m, float *input)
{
    const size_t count = (size_t)n * m; /* size_t: n * m may exceed 32 bits */
    size_t i;
    for (i = 0; i < count; i++) {
        input[i] = 0.0f;
    }
}

/* Print an n x m row-major matrix, one row per line, followed by a blank line. */
void printMatrix(unsigned int n, unsigned int m, float *input)
{
    unsigned int i, j;
    for (i = 0; i < n; i++) {
        for (j = 0; j < m; j++) {
            printf("%2.0f ", input[(size_t)i * m + j]);
        }
        printf("\n");
    }
    printf("\n");
}