Hi everyone,
I was eager to try the new CUDA Graph API but ran into an issue with cuBLAS.
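Everything works fine until I call cublasSgemmStridedBatched with batchCount = 1 while the stream is being captured.
The code below works with batchCount = 2, and it also works with batchCount = 1 as long as I don’t capture the stream. It is shown with batchCount = 2; change it to 1 to reproduce the failure.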
#include <cuda_runtime.h>
#include <helper_cuda.h>
#include <vector>
#include <cublas_v2.h>
// Print a diagnostic for a failed cuBLAS call, including the numeric
// cublasStatus_t value so the failure can be identified.
// Returns true when `status` is CUBLAS_STATUS_SUCCESS.
static bool reportCublas(cublasStatus_t status, const char *what) {
  if (status == CUBLAS_STATUS_SUCCESS) {
    return true;
  }
  fprintf(stderr, "!!!! %s (cublasStatus_t = %d)\n", what, (int)status);
  return false;
}

// Minimal reproducer: cublasSgemmStridedBatched issued on a stream that is
// being captured into a CUDA graph.
//
// Usage: <prog> [batchCount]   (default batchCount = 2)
//
// Each batch computes C = alpha * A * B + beta * C with A 256x32, B 32x32 and
// C 256x32 (column-major, the cuBLAS convention). Two GEMMs are issued eagerly,
// then two more while stream1 is captured. Returns EXIT_FAILURE if any cuBLAS
// call or the capture fails.
int main(int argc, char **argv) {
  // GEMM shape: C[m x n] = alpha * A[m x k] * B[k x n] + beta * C[m x n].
  const int m = 256;
  const int n = 32;
  const int k = 32;
  const int lda = m;
  const int ldb = k;
  const int ldc = m;
  // Element distance between consecutive matrices of a batch
  // (8192, 1024 and 8192 respectively, as in the original repro).
  const long long strideA = (long long)lda * k;
  const long long strideB = (long long)ldb * n;
  const long long strideC = (long long)ldc * n;

  // Taking batchCount from the command line lets 1 vs 2 be compared without
  // recompiling.
  int batchCount = 2;
  if (argc > 1) {
    batchCount = atoi(argv[1]);
  }
  if (batchCount < 1) {
    fprintf(stderr, "!!!! batchCount must be >= 1\n");
    return EXIT_FAILURE;
  }

  cublasHandle_t handle;
  if (!reportCublas(cublasCreate(&handle), "CUBLAS initialization error")) {
    // Nothing below can work without a valid handle.
    return EXIT_FAILURE;
  }

  cudaStream_t stream1;
  checkCudaErrors(cudaStreamCreate(&stream1));

  // Size each buffer for ALL batches (stride * batchCount). Sizing for a single
  // matrix makes every batch after the first read/write past the end of A and C.
  const size_t countA = (size_t)strideA * batchCount;
  const size_t countB = (size_t)strideB * batchCount;
  const size_t countC = (size_t)strideC * batchCount;
  float *a = NULL, *b = NULL, *c = NULL;
  checkCudaErrors(cudaMalloc(&a, sizeof(float) * countA));
  checkCudaErrors(cudaMalloc(&b, sizeof(float) * countB));
  checkCudaErrors(cudaMalloc(&c, sizeof(float) * countC));
  // cudaMalloc leaves memory uninitialised, and beta = 1 makes the GEMM read C,
  // so zero all inputs to keep the results well-defined.
  checkCudaErrors(cudaMemset(a, 0, sizeof(float) * countA));
  checkCudaErrors(cudaMemset(b, 0, sizeof(float) * countB));
  checkCudaErrors(cudaMemset(c, 0, sizeof(float) * countC));

  const float alpha = 1.0f;
  const float beta = 1.0f;

  // Every launch below is identical; only the context (eager vs. captured)
  // differs.
  auto gemm = [&]() {
    return cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                                     &alpha, a, lda, strideA, b, ldb, strideB,
                                     &beta, c, ldc, strideC, batchCount);
  };

  bool ok = reportCublas(cublasSetStream(handle, stream1),
                         "cublasSetStream error");

  // Eager (uncaptured) launches.
  ok = reportCublas(gemm(), "kernel execution error.") && ok;
  ok = reportCublas(gemm(), "kernel execution 2 error.") && ok;
  // Surface any asynchronous error from the eager launches now, so it is not
  // misattributed to the capture below.
  checkCudaErrors(cudaStreamSynchronize(stream1));

  // Captured launches.
  cudaGraph_t graph = NULL;
#if CUDART_VERSION >= 10010
  // CUDA 10.1 added a mandatory capture-mode argument.
  checkCudaErrors(cudaStreamBeginCapture(stream1, cudaStreamCaptureModeGlobal));
#else
  checkCudaErrors(cudaStreamBeginCapture(stream1));
#endif
  ok = reportCublas(gemm(), "kernel capture error.") && ok;
  ok = reportCublas(gemm(), "kernel capture 2 error.") && ok;
  // If an operation issued during capture was not allowed, the capture is
  // invalidated and EndCapture returns an error. Report it rather than
  // aborting, so the cleanup below still runs.
  const cudaError_t captureErr = cudaStreamEndCapture(stream1, &graph);
  if (captureErr != cudaSuccess) {
    fprintf(stderr, "!!!! stream capture failed: %s\n",
            cudaGetErrorString(captureErr));
    ok = false;
  }
  if (graph != NULL) {
    checkCudaErrors(cudaGraphDestroy(graph));
  }

  // Release the handle (previously leaked) before the stream it was bound to.
  ok = reportCublas(cublasDestroy(handle), "cublasDestroy error") && ok;
  checkCudaErrors(cudaStreamDestroy(stream1));
  checkCudaErrors(cudaFree(a));
  checkCudaErrors(cudaFree(b));
  checkCudaErrors(cudaFree(c));
  return ok ? EXIT_SUCCESS : EXIT_FAILURE;
}
Interestingly, the capture doesn’t fail until the second call; the first captured call succeeds.
Does anyone have insight on what I could be doing wrong or is this just a bug?
Thanks,
Felipe