Where can I find working examples for the new cuBLASLt library?

I’m trying to use the new library cuBLASLt released with CUDA 10.1.

I can’t get it working, so I’m looking for working examples that I could modify to match my needs. There are extracts in the documentation, but only a few subroutines are shown, not a full program. The CUDA samples don’t include an example either (even on GitHub).

Is there somewhere a full example I can use?

Looking at the programming guide, in the section Using the cuBLASLt API under subsection 3.2.1. Single Precision GEMM, you’ll see an example that is nearly a drop-in replacement for cublasSgemm. You can start with the CUDA sample in <samples_location>7_CUDALibraries/simpleCUBLAS and replace the cublasSgemm call with the 3.2.1 example. See below…

Notice that you can’t make this replacement without a few type changes. For simplicity, workspace = nullptr and workspaceSize = 0.
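For what it’s worth, the LtSgemm wrapper below accepts a workspace. If you want to hand cuBLASLt real scratch space instead of nullptr/0, a minimal sketch (the 4 MiB size is an arbitrary illustration, not a documented requirement):

void *workspace = nullptr;
size_t workspaceSize = 4 * 1024 * 1024; // arbitrary example size
if (cudaMalloc(&workspace, workspaceSize) != cudaSuccess) {
    workspace = nullptr; // fall back to running without a workspace
    workspaceSize = 0;
}
// ...
status = LtSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, N, N, &alpha, d_A,
                 N, d_B, N, &beta, d_C, N, workspace, workspaceSize);
// ...
cudaFree(workspace);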

/*
 * Copyright 1993-2017 NVIDIA Corporation.  All rights reserved.
 *
 * NOTICE TO USER:
 *
 * This source code is subject to NVIDIA ownership rights under U.S. and
 * international Copyright laws.  Users and possessors of this source code
 * are hereby granted a nonexclusive, royalty-free license to use this code
 * in individual and commercial software.
 *
 * NVIDIA MAKES NO REPRESENTATION ABOUT THE SUITABILITY OF THIS SOURCE
 * CODE FOR ANY PURPOSE.  IT IS PROVIDED "AS IS" WITHOUT EXPRESS OR
 * IMPLIED WARRANTY OF ANY KIND.  NVIDIA DISCLAIMS ALL WARRANTIES WITH
 * REGARD TO THIS SOURCE CODE, INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY, NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
 * IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL,
 * OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
 * OF USE, DATA OR PROFITS,  WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
 * OR OTHER TORTIOUS ACTION,  ARISING OUT OF OR IN CONNECTION WITH THE USE
 * OR PERFORMANCE OF THIS SOURCE CODE.
 *
 * U.S. Government End Users.   This source code is a "commercial item" as
 * that term is defined at  48 C.F.R. 2.101 (OCT 1995), consisting  of
 * "commercial computer  software"  and "commercial computer software
 * documentation" as such terms are  used in 48 C.F.R. 12.212 (SEPT 1995)
 * and is provided to the U.S. Government only as a commercial end item.
 * Consistent with 48 C.F.R.12.212 and 48 C.F.R. 227.7202-1 through
 * 227.7202-4 (JUNE 1995), all U.S. Government End Users acquire the
 * source code with only those rights set forth herein.
 *
 * Any use of this source code in individual and commercial software must
 * include, in the user documentation and internal comments to the code,
 * the above Disclaimer and U.S. Government End Users Notice.
 */

/* This example demonstrates how to use the cuBLASLt library
 * by performing a single-precision matrix multiplication on the
 * device and comparing the result to the same operation performed
 * on the host.
 */

/* Includes, system */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Includes, cuda */
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "helper_cuda.h"

/* Matrix size */
#define N (4096)

cublasStatus_t
LtSgemm(cublasLtHandle_t ltHandle,
       cublasOperation_t transa,
       cublasOperation_t transb,
       int m,
       int n,
       int k,
       const float *alpha, /* host pointer */
       const float *A,
       int lda,
       const float *B,
       int ldb,
       const float *beta, /* host pointer */
       float *C,
       int ldc,
       void *workspace,
       size_t workspaceSize) {
   cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

   cublasLtMatmulDesc_t operationDesc = NULL;
   cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
   cublasLtMatmulPreference_t preference = NULL;

   int returnedResults                             = 0;
   cublasLtMatmulHeuristicResult_t heuristicResult = {};

   // Create operation descriptor; see cublasLtMatmulDescAttributes_t
   // for details about defaults; here we just set the transforms for
   // A and B.
   status = cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32F);
   if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
   status = cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
   if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
   status = cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
   if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

   // Create matrix descriptors. Not setting any extra attributes.
   status = cublasLtMatrixLayoutCreate(
       &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
   if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
   status = cublasLtMatrixLayoutCreate(
       &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
   if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
   status = cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32F, m, n, ldc);
   if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

   // Create preference handle; In general, extra attributes can be
   // used here to disable tensor ops or to make sure algo selected
   // will work with badly aligned A, B, C. However, for simplicity
   // here we assume A,B,C are always well aligned (e.g., directly
   // come from cudaMalloc)
   status = cublasLtMatmulPreferenceCreate(&preference);
   if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
   status = cublasLtMatmulPreferenceSetAttribute(
       preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
   if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

   // We just need the best available heuristic to try and run matmul.
   // There is no guarantee that this will work. For example, if A is
   // badly aligned, you can request more (e.g. 32) algos and try to
   // run them one by one until something works.
   status = cublasLtMatmulAlgoGetHeuristic(
       ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, preference, 1, &heuristicResult, &returnedResults);
   if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

   if (returnedResults == 0) {
       status = CUBLAS_STATUS_NOT_SUPPORTED;
       goto CLEANUP;
   }

   status = cublasLtMatmul(ltHandle,
                           operationDesc,
                           alpha,
                           A,
                           Adesc,
                           B,
                           Bdesc,
                           beta,
                           C,
                           Cdesc,
                           C,
                           Cdesc,
                           &heuristicResult.algo,
                           workspace,
                           workspaceSize,
                           0);

CLEANUP:
   // Descriptors are no longer needed as all GPU work was already
   // enqueued.
   if (preference) cublasLtMatmulPreferenceDestroy(preference);
   if (Cdesc) cublasLtMatrixLayoutDestroy(Cdesc);
   if (Bdesc) cublasLtMatrixLayoutDestroy(Bdesc);
   if (Adesc) cublasLtMatrixLayoutDestroy(Adesc);
   if (operationDesc) cublasLtMatmulDescDestroy(operationDesc);
   return status;
}
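As the comment inside LtSgemm notes, the single best heuristic result is not guaranteed to run. A sketch of the fallback it suggests, replacing the single-result heuristic call and matmul inside LtSgemm (the count of 8 is an arbitrary choice; the extra declarations would go with the others before the first goto):

// Request several heuristic results and try them in order until one runs.
const int requestedAlgoCount = 8;
cublasLtMatmulHeuristicResult_t heuristicResults[requestedAlgoCount] = {};
int returnedResults = 0;

status = cublasLtMatmulAlgoGetHeuristic(
    ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, preference,
    requestedAlgoCount, heuristicResults, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

status = CUBLAS_STATUS_NOT_SUPPORTED; // assume failure until an algo works
for (int i = 0; i < returnedResults; ++i) {
    cublasStatus_t tryStatus = cublasLtMatmul(
        ltHandle, operationDesc, alpha, A, Adesc, B, Bdesc, beta,
        C, Cdesc, C, Cdesc, &heuristicResults[i].algo,
        workspace, workspaceSize, 0);
    if (tryStatus == CUBLAS_STATUS_SUCCESS) {
        status = tryStatus;
        break; // first algorithm that runs wins
    }
}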

/* Host implementation of a simple version of sgemm */
static void simple_sgemm(int n, float alpha, const float *A, const float *B,
                         float beta, float *C) {
  int i;
  int j;
  int k;

  for (i = 0; i < n; ++i) {
    for (j = 0; j < n; ++j) {
      float prod = 0;

      for (k = 0; k < n; ++k) {
        prod += A[k * n + i] * B[j * n + k];
      }

      C[j * n + i] = alpha * prod + beta * C[j * n + i];
    }
  }
}

/* Main */
int main(int argc, char **argv) {
  cublasStatus_t status;
  float *h_A;
  float *h_B;
  float *h_C;
  float *h_C_ref;
  float *d_A = 0;
  float *d_B = 0;
  float *d_C = 0;
  float alpha = 1.0f;
  float beta = 0.0f;
  int n2 = N * N;
  int i;
  float error_norm;
  float ref_norm;
  float diff;
  cublasLtHandle_t handle;

  int dev = findCudaDevice(argc, (const char **)argv);

  if (dev == -1) {
    return EXIT_FAILURE;
  }

  /* Initialize CUBLAS */
  printf("simpleCUBLAS test running..\n");

  status = cublasLtCreate(&handle);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "!!!! CUBLAS initialization error\n");
    return EXIT_FAILURE;
  }

  /* Allocate host memory for the matrices */
  h_A = reinterpret_cast<float *>(malloc(n2 * sizeof(h_A[0])));

  if (h_A == 0) {
    fprintf(stderr, "!!!! host memory allocation error (A)\n");
    return EXIT_FAILURE;
  }

  h_B = reinterpret_cast<float *>(malloc(n2 * sizeof(h_B[0])));

  if (h_B == 0) {
    fprintf(stderr, "!!!! host memory allocation error (B)\n");
    return EXIT_FAILURE;
  }

  h_C = reinterpret_cast<float *>(malloc(n2 * sizeof(h_C[0])));

  if (h_C == 0) {
    fprintf(stderr, "!!!! host memory allocation error (C)\n");
    return EXIT_FAILURE;
  }

  /* Fill the matrices with test data */
  for (i = 0; i < n2; i++) {
    h_A[i] = rand() / static_cast<float>(RAND_MAX);
    h_B[i] = rand() / static_cast<float>(RAND_MAX);
    h_C[i] = rand() / static_cast<float>(RAND_MAX);
  }

  /* Allocate device memory for the matrices */
  if (cudaMalloc(reinterpret_cast<void **>(&d_A), n2 * sizeof(d_A[0])) !=
      cudaSuccess) {
    fprintf(stderr, "!!!! device memory allocation error (allocate A)\n");
    return EXIT_FAILURE;
  }

  if (cudaMalloc(reinterpret_cast<void **>(&d_B), n2 * sizeof(d_B[0])) !=
      cudaSuccess) {
    fprintf(stderr, "!!!! device memory allocation error (allocate B)\n");
    return EXIT_FAILURE;
  }

  if (cudaMalloc(reinterpret_cast<void **>(&d_C), n2 * sizeof(d_C[0])) !=
      cudaSuccess) {
    fprintf(stderr, "!!!! device memory allocation error (allocate C)\n");
    return EXIT_FAILURE;
  }

  /* Initialize the device matrices with the host matrices */
  status = cublasSetVector(n2, sizeof(h_A[0]), h_A, 1, d_A, 1);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "!!!! device access error (write A)\n");
    return EXIT_FAILURE;
  }

  status = cublasSetVector(n2, sizeof(h_B[0]), h_B, 1, d_B, 1);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "!!!! device access error (write B)\n");
    return EXIT_FAILURE;
  }

  status = cublasSetVector(n2, sizeof(h_C[0]), h_C, 1, d_C, 1);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "!!!! device access error (write C)\n");
    return EXIT_FAILURE;
  }

  /* Performs operation using plain C code */
  simple_sgemm(N, alpha, h_A, h_B, beta, h_C);
  h_C_ref = h_C;

  /* Original cublasSgemm call, replaced by LtSgemm below */
  //  status = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, N, N, &alpha, d_A,
  //                       N, d_B, N, &beta, d_C, N);

  status = LtSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, N, N, &alpha, d_A,
                         N, d_B, N, &beta, d_C, N, nullptr, 0);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "!!!! kernel execution error.\n");
    return EXIT_FAILURE;
  }

  /* Allocate host memory for reading back the result from device memory */
  h_C = reinterpret_cast<float *>(malloc(n2 * sizeof(h_C[0])));

  if (h_C == 0) {
    fprintf(stderr, "!!!! host memory allocation error (C)\n");
    return EXIT_FAILURE;
  }

  /* Read the result back */
  status = cublasGetVector(n2, sizeof(h_C[0]), d_C, 1, h_C, 1);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "!!!! device access error (read C)\n");
    return EXIT_FAILURE;
  }

  /* Check result against reference */
  error_norm = 0;
  ref_norm = 0;

  for (i = 0; i < n2; ++i) {
    diff = h_C_ref[i] - h_C[i];
    error_norm += diff * diff;
    ref_norm += h_C_ref[i] * h_C_ref[i];
  }

  error_norm = static_cast<float>(sqrt(static_cast<double>(error_norm)));
  ref_norm = static_cast<float>(sqrt(static_cast<double>(ref_norm)));

  if (fabs(ref_norm) < 1e-7) {
    fprintf(stderr, "!!!! reference norm is 0\n");
    return EXIT_FAILURE;
  }

  /* Memory clean up */
  free(h_A);
  free(h_B);
  free(h_C);
  free(h_C_ref);

  if (cudaFree(d_A) != cudaSuccess) {
    fprintf(stderr, "!!!! memory free error (A)\n");
    return EXIT_FAILURE;
  }

  if (cudaFree(d_B) != cudaSuccess) {
    fprintf(stderr, "!!!! memory free error (B)\n");
    return EXIT_FAILURE;
  }

  if (cudaFree(d_C) != cudaSuccess) {
    fprintf(stderr, "!!!! memory free error (C)\n");
    return EXIT_FAILURE;
  }

  /* Shutdown */
  status = cublasLtDestroy(handle);

  if (status != CUBLAS_STATUS_SUCCESS) {
    fprintf(stderr, "!!!! shutdown error (A)\n");
    return EXIT_FAILURE;
  }

  if (error_norm / ref_norm < 1e-3f) {
    printf("simpleCUBLAS test passed.\n");
    exit(EXIT_SUCCESS);
  } else {
    printf("simpleCUBLAS test failed.\n");
    exit(EXIT_FAILURE);
  }
}

While this example compiles, it doesn’t appear to work on the first try. It hangs on a V100:

GPU Device 0: “Tesla V100-PCIE-16GB” with compute capability 7.0

simpleCUBLAS test running…

I think you have to wait a very long time for the CPU verification routine to finish. Study the code.

Try reducing N in the code to confirm this if you don’t feel like waiting.

By way of example, on my not-particularly-fast-CPU system, with N=256 the code executes “immediately”.
With N = 1024, the code takes ~9 seconds.
With N = 2048, the code takes 3.5 minutes.

Nearly all of this time is spent in the CPU code.

I don’t know how long it takes with N=4096, but I’m skeptical of claims of a hang.
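If you want to confirm where the time goes yourself, a minimal sketch wrapping the host reference with std::chrono (illustrative addition, not part of the original sample):

#include <chrono>
// ...
auto t0 = std::chrono::steady_clock::now();
simple_sgemm(N, alpha, h_A, h_B, beta, h_C);
auto t1 = std::chrono::steady_clock::now();
printf("host sgemm took %.1f s\n",
       std::chrono::duration<double>(t1 - t0).count());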

That’s right. I noticed it right before seeing this: the CPU version does take extremely long. At first I thought it was a fairly small matrix, but it’s not.

And by extremely long, I’m talking about minutes compared to < 1 second on the GPU.

It’s not hanging; the CPU code is just very slow. Also, for anyone using this example, keep in mind that it will not use tensor cores by default. I’m still working on an example that uses them, but no luck so far.

Here is an example that uses tensor operations. There is a lot going on in the code, so please study it before jumping in. It provides a couple of knobs to turn to prove to yourself that it’s working correctly.

There are two scenarios that use tensor operations with complex half precision. Following the typedefs in the code below, the five types listed are the A, B, and C data types, then the scale and compute types:

  1. CUDA_C_16F, CUDA_C_16F, CUDA_C_16F, CUDA_C_32F, CUDA_C_32F
  2. CUDA_C_16F, CUDA_C_16F, CUDA_C_32F, CUDA_C_32F, CUDA_C_32F

First, set IDENTITY 1 and PRINT 1. This creates two identity input matrices, A and B. The result should print a 16x16 identity matrix. You’ll notice the values are printed in pairs, showing the real and imaginary parts.

Next, set PRINT 0 and you can test multiple square matrices with the identity test.

Lastly, set IDENTITY 0 and PRINT 0 and you will test multiple square and non-square matrices with randomly generated data.

You should notice another define, TIME_TRANSFORM. When matrices A and B are generated with random numbers, they are generated in an interleaved layout, meaning the data is stored [real, imaginary, real, imaginary, …]. In order to utilize Tensor Cores, the data must be in planar layout, meaning all the real parts are stored first, followed by all the imaginary parts: [real, real, real, … (halfway), imaginary, imaginary, imaginary].
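To make the layout concrete, here is an illustrative host-side sketch of the conversion (a hypothetical helper; the example itself does this on the device with cublasLtMatrixTransform and the CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET attribute):

#include <cuda_fp16.h>

// Copy n interleaved complex-half values into planar layout:
// n real parts first, then n imaginary parts.
void interleavedToPlanar( half2 const *src, __half *dst, size_t n ) {
	for ( size_t i = 0; i < n; i++ ) {
		dst[i]     = src[i].x; // real plane
		dst[n + i] = src[i].y; // imaginary plane
	}
}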

When TIME_TRANSFORM is 1, the measured time includes transforming A and B to planar layout, performing the matrix multiplication, and transforming C back from planar to interleaved layout. When TIME_TRANSFORM is 0, only the matrix multiplication is profiled.

That should cover everything and I hope it helps.

/*
 * Copyright 1993-2017 NVIDIA Corporation.  All rights reserved.
 * (Same NVIDIA sample license notice as in the first example above.)
 */

/* This example demonstrates how to use the cuBLASLt library
 * to perform a complex half-precision matrix multiplication with
 * tensor operations, transforming the data between interleaved
 * and planar layouts.
 */

/* Includes, system */
#include <cstdio>

/* Includes, cuda */
#include <cublasLt.h>
#include <cuda_runtime.h>

#include <thrust/complex.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/random.h>

#include "helper_cuda.h"

#define PRINT 0
#define IDENTITY 0
#define TIME_TRANSFORM 1
#define SCENARIO 0

auto constexpr kernelRepeats = 10;
auto constexpr threadsPerBlock = 1024;

#if SCENARIO == 0 // CUDA_C_16F, CUDA_C_16F, CUDA_C_16F, CUDA_C_32F, CUDA_C_32F
auto constexpr cudaTypeI = CUDA_C_16F;
typedef half2 dataTypeI;
auto constexpr cudaTypeO = CUDA_C_16F;
typedef half2 dataTypeO;
//auto constexpr cudaTypeS = CUDA_C_32F;
typedef thrust::complex<float> dataTypeS;
auto constexpr cudaTypeCom = CUDA_C_32F;

#elif SCENARIO == 1 // CUDA_C_16F, CUDA_C_16F, CUDA_C_32F, CUDA_C_32F, CUDA_C_32F
auto constexpr cudaTypeI = CUDA_C_16F;
typedef half2 dataTypeI;
auto constexpr cudaTypeO = CUDA_C_32F;
typedef thrust::complex<float> dataTypeO;
//auto constexpr cudaTypeS = CUDA_C_32F;
typedef thrust::complex<float> dataTypeS;
auto constexpr cudaTypeCom = CUDA_C_32F;
#endif

struct GenRand {
	__device__
	dataTypeI operator ()( int const & idx ) {
		dataTypeI result;
		thrust::default_random_engine randEng;
		thrust::uniform_real_distribution<float> uniDist;
		randEng.discard( idx );
		result.x = __float2half( uniDist( randEng ) );
		result.y = __float2half( uniDist( randEng ) );
		return ( result );
	}
};

struct setIdentity {
	int const m;
	setIdentity( int const & _m ) :
			m( _m ) {
	}
	__device__
	dataTypeI operator ()( int const & idx ) {
		dataTypeI result;
		result.x = __float2half( 0.0f );
		result.y = __float2half( 0.0f );
		int const diagIdx = ( m + 1 );	// Since we are using complex half.
		if ( idx % ( diagIdx ) == 0 ) result.x = __float2half( 1.0f );
		return ( result );
	}
};

template<typename Pointer>
__global__ void __launch_bounds__(threadsPerBlock) checkIdentity( int const n, int const m, dataTypeO const * d_C, Pointer d_p ) {
	for ( int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n; tid += blockDim.x * gridDim.x ) {
		int const diagIdx = m + 1;
#if SCENARIO == 0
		if ( tid % ( diagIdx ) == 0 ) { // If thread index is on the diagonal
			if ( __hge( fabsf( d_C[tid].x - __float2half(1.0f) ), __float2half(1e-7f) ) ) *d_p = false; // abs( d_C - 1.0f ) > 1e-7
		} else if ( __hge( d_C[tid].x, __float2half(1e-7f) ) ) *d_p = false;
#elif SCENARIO == 1
		if ( tid % ( diagIdx ) == 0 ) { // If thread index is on the diagonal
			if ( fabsf( d_C[tid].real( ) - 1.0f ) > 1e-7f ) *d_p = false;
		} else if ( d_C[tid].real( ) > 1e-7f ) *d_p = false;
#endif
	}
};

void LtSgemm(
		cublasLtHandle_t ltHandle,
		cublasOperation_t transa,
		cublasOperation_t transb,
		int const & m,
		int const & n,
		int const & k,
		dataTypeS const *alpha,
		int const & sizeA,
		dataTypeI const *A,
		int const & lda,
		int const & sizeB,
		dataTypeI const *B,
		int const & ldb,
		dataTypeS const *beta,
		int const & sizeC,
		dataTypeO *C,
		int const & ldc,
		void *workSpace,
		size_t workSpaceSize ) {

	// The offset should start right after real data
	size_t planarOffsetA = ( sizeA * sizeof(dataTypeI) ) / 2;
	size_t planarOffsetB = ( sizeB * sizeof(dataTypeI) ) / 2;
	size_t planarOffsetC = ( sizeC * sizeof(dataTypeO) ) / 2;

	cublasLtMatmulDesc_t operationDesc = nullptr;
	cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr;

	cublasLtMatmulPreference_t preference = nullptr;

	dataTypeI * Atransform, *Btransform;
	dataTypeO * Ctransform;
	cublasLtMatrixTransformDesc_t transformDescI = nullptr, transformDescO = nullptr;
	cublasLtMatrixLayout_t AtransformDesc = nullptr, BtransformDesc = nullptr, CtransformDesc = nullptr;

	// Allocate memory for transformed matrix
	checkCudaErrors( cudaMalloc( reinterpret_cast<void**>(&Atransform), sizeA * sizeof(dataTypeI) ) );
	checkCudaErrors( cudaMalloc( reinterpret_cast<void**>(&Btransform), sizeB * sizeof(dataTypeI) ) );
	checkCudaErrors( cudaMalloc( reinterpret_cast<void**>(&Ctransform), sizeC * sizeof(dataTypeO) ) );

	// Create preference handle; In general, extra attributes can be
	// used here to disable tensor ops or to make sure algo selected
	// will work with badly aligned A, B, C. However, for simplicity
	// here we assume A,B,C are always well aligned (e.g., directly
	// come from cudaMalloc)
	checkCudaErrors( cublasLtMatmulPreferenceCreate( &preference ) );
	checkCudaErrors(
			cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workSpaceSize, sizeof( workSpaceSize ) ) );

	// Create operation descriptor; see cublasLtMatmulDescAttributes_t
	// for details about defaults; here we just set the transforms for
	// A and B.
	checkCudaErrors( cublasLtMatmulDescCreate( &operationDesc, cudaTypeCom ) );
	checkCudaErrors( cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof( transa ) ) );
	checkCudaErrors( cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof( transb ) ) );

	// Create matrix descriptors for interleaved data. Not setting any extra attributes.
	checkCudaErrors( cublasLtMatrixLayoutCreate( &Adesc, cudaTypeI, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda ) );
	checkCudaErrors( cublasLtMatrixLayoutCreate( &Bdesc, cudaTypeI, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb ) );
	checkCudaErrors( cublasLtMatrixLayoutCreate( &Cdesc, cudaTypeO, m, n, ldc ) );

	// Create transform descriptor to convert interleaved to planar
	checkCudaErrors( cublasLtMatrixTransformDescCreate( &transformDescI, cudaTypeCom ) );
	checkCudaErrors( cublasLtMatrixTransformDescCreate( &transformDescO, cudaTypeCom ) );

	// Create matrix descriptors for planar data. Not setting any extra attributes.
	// Need to double check 3rd parameter
	checkCudaErrors( cublasLtMatrixLayoutCreate( &AtransformDesc, cudaTypeI, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda ) );
	checkCudaErrors( cublasLtMatrixLayoutCreate( &BtransformDesc, cudaTypeI, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb ) );
	checkCudaErrors( cublasLtMatrixLayoutCreate( &CtransformDesc, cudaTypeO, m, n, ldc ) );

	// Configure inputs and outputs to use planar layout
	checkCudaErrors(
			cublasLtMatrixLayoutSetAttribute( AtransformDesc, CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET, &planarOffsetA, sizeof( planarOffsetA ) ) );
	checkCudaErrors(
			cublasLtMatrixLayoutSetAttribute( BtransformDesc, CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET, &planarOffsetB, sizeof( planarOffsetB ) ) );
	checkCudaErrors(
			cublasLtMatrixLayoutSetAttribute( CtransformDesc, CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET, &planarOffsetC, sizeof( planarOffsetC ) ) );

	// Create CUDA event to time the execution time of each algo
	cudaEvent_t startEvent = nullptr, stopEvent = nullptr;
	cudaStream_t stream = nullptr;

#if TIME_TRANSFORM == 0
	checkCudaErrors(
			cublasLtMatrixTransform(
					ltHandle,
					transformDescI,
					alpha,
					A,
					Adesc,
					beta,
					NULL,
					NULL,
					Atransform,
					AtransformDesc,
					stream ) );

	checkCudaErrors(
			cublasLtMatrixTransform(
					ltHandle,
					transformDescI,
					alpha,
					B,
					Bdesc,
					beta,
					nullptr,
					nullptr,
					Btransform,
					BtransformDesc,
					stream ) );
#endif

	checkCudaErrors( cudaEventCreate( &startEvent, cudaEventBlockingSync ) );
	checkCudaErrors( cudaEventCreate( &stopEvent, cudaEventBlockingSync ) );
	checkCudaErrors( cudaEventRecord( startEvent, stream ) );

	for ( int loop = 0; loop < kernelRepeats; loop++ ) {

#if TIME_TRANSFORM == 1
		// Transform interleaved data to planar
		checkCudaErrors(
				cublasLtMatrixTransform(
						ltHandle,
						transformDescI,
						alpha,
						A,
						Adesc,
						beta,
						NULL,
						NULL,
						Atransform,
						AtransformDesc,
						stream ) );

		checkCudaErrors(
				cublasLtMatrixTransform(
						ltHandle,
						transformDescI,
						alpha,
						B,
						Bdesc,
						beta,
						nullptr,
						nullptr,
						Btransform,
						BtransformDesc,
						stream ) );
#endif

		checkCudaErrors(
				cublasLtMatmul(
						ltHandle,
						operationDesc,
						alpha,
						Atransform,
						AtransformDesc,
						Btransform,
						BtransformDesc,
						beta,
						Ctransform,
						CtransformDesc,
						Ctransform,
						CtransformDesc,
						nullptr,
						workSpace,
						workSpaceSize,
						stream ) );

#if TIME_TRANSFORM == 1
		// Transform planar to interleaved data in output matrix
		checkCudaErrors(
				cublasLtMatrixTransform(
						ltHandle,
						transformDescO,
						alpha,
						Ctransform,
						CtransformDesc,
						beta,
						nullptr,
						nullptr,
						C,
						Cdesc,
						stream ) );
#endif

	}

	checkCudaErrors( cudaEventRecord( stopEvent, stream ) );
	checkCudaErrors( cudaEventSynchronize( stopEvent ) );
	float time;
	checkCudaErrors( cudaEventElapsedTime( &time, startEvent, stopEvent ) );

#if TIME_TRANSFORM == 0
	// Transform planar to interleaved data in output matrix
	checkCudaErrors(
			cublasLtMatrixTransform(
					ltHandle,
					transformDescO,
					alpha,
					Ctransform,
					CtransformDesc,
					beta,
					nullptr,
					nullptr,
					C,
					Cdesc,
					stream ) );
#endif

	double timeAvg = ( time * 1e-3 ) / kernelRepeats;	// Convert to seconds, then divide by loops
	double gflop = ( 8 * static_cast<unsigned long long int>( m * n ) * k ) * 1e-9;	// Complex

	printf(
#if IDENTITY
			"%d %d %d %d %d %d %d %0.2f %0.0f ",
#else
			"%d %d %d %d %d %d %d %0.2f %0.0f\n",
#endif
			m,
			n,
			k,
			cudaTypeI,
			cudaTypeI,
			cudaTypeO,
			cudaTypeCom,
			time,
			gflop/timeAvg );

	// Descriptors are no longer needed as all GPU work was already enqueued.
	checkCudaErrors( cublasLtMatmulPreferenceDestroy( preference ) );
	checkCudaErrors( cublasLtMatrixLayoutDestroy( Cdesc ) );
	checkCudaErrors( cublasLtMatrixLayoutDestroy( Bdesc ) );
	checkCudaErrors( cublasLtMatrixLayoutDestroy( Adesc ) );
	checkCudaErrors( cublasLtMatmulDescDestroy( operationDesc ) );
	checkCudaErrors( cudaFree ( Atransform ) );
	checkCudaErrors( cudaFree ( Btransform ) );
	checkCudaErrors( cudaFree ( Ctransform ) );
	checkCudaErrors( cublasLtMatrixLayoutDestroy( AtransformDesc ) );
	checkCudaErrors( cublasLtMatrixLayoutDestroy( BtransformDesc ) );
	checkCudaErrors( cublasLtMatrixLayoutDestroy( CtransformDesc ) );
	checkCudaErrors( cublasLtMatrixTransformDescDestroy( transformDescI ) );
	checkCudaErrors( cublasLtMatrixTransformDescDestroy( transformDescO ) );
	checkCudaErrors( cudaEventDestroy( startEvent ) );
	checkCudaErrors( cudaEventDestroy( stopEvent ) );
}

void calculate( int const & m, int const & n, int const & k, int & count, int const & square ) {

	dataTypeS alpha = 1.0f;
	dataTypeS beta = 0.0f;
	int lda = m, ldb = k, ldc = m;
	void *d_workspace = nullptr;

	size_t sizeA = m * k;
	size_t sizeB = k * n;
	size_t sizeC = m * n;
	size_t workspace = 4096;

	cublasLtHandle_t handle;

	/* Initialize CUBLAS */
	checkCudaErrors( cublasLtCreate( &handle ) );

	/* Allocate device memory for workspace */
	checkCudaErrors( cudaMalloc( (void **)&d_workspace, workspace) );

	/* Allocate device memory for the matrices */
	thrust::device_vector<dataTypeI> d_A( sizeA, __float2half2_rn(0.0f) );
	thrust::device_vector<dataTypeI> d_B( sizeB, __float2half2_rn(0.0f) );
#if SCENARIO == 0
	thrust::device_vector<dataTypeO> d_C( sizeC, __float2half2_rn(0.0f) );
#elif SCENARIO == 1
	thrust::device_vector<dataTypeO> d_C( sizeC, 0.0f );
#endif

	/* Retrieve raw pointer for device data */
	dataTypeI * d_A_ptr = thrust::raw_pointer_cast( &d_A[0] );
	dataTypeI * d_B_ptr = thrust::raw_pointer_cast( &d_B[0] );
	dataTypeO * d_C_ptr = thrust::raw_pointer_cast( &d_C[0] );

#if IDENTITY
	/* Generate identity matrix on device */
	thrust::transform(
			thrust::make_counting_iterator( 0 ),
			thrust::make_counting_iterator( static_cast<int>( sizeA ) ),
			d_A.begin( ),
			setIdentity( m ) );
	thrust::transform(
			thrust::make_counting_iterator( 0 ),
			thrust::make_counting_iterator( static_cast<int>( sizeB ) ),
			d_B.begin( ),
			setIdentity( m ) );
#else
	/* Generate random data on device */
	thrust::transform( thrust::make_counting_iterator( 0 ), thrust::make_counting_iterator( static_cast<int>( sizeA ) ), d_A.begin( ), GenRand( ) );
	thrust::transform( thrust::make_counting_iterator( 0 ), thrust::make_counting_iterator( static_cast<int>( sizeB ) ), d_B.begin( ), GenRand( ) );
#endif

	printf( "%d %d ", count, square );
	count++;

	LtSgemm(
			handle,
			CUBLAS_OP_N,
			CUBLAS_OP_N,
			m,
			n,
			k,
			&alpha,
			sizeA,
			d_A_ptr,
			lda,
			sizeB,
			d_B_ptr,
			ldb,
			&beta,
			sizeC,
			d_C_ptr,
			ldc,
			d_workspace,
			workspace );

#if IDENTITY
	/* Generate device vector to hold flag */
	thrust::device_vector<bool> d_p(1, true);

	checkIdentity<<<sizeC/threadsPerBlock + 1, threadsPerBlock>>>( sizeC, m, d_C_ptr, d_p.data());

	/* Copy device flag to host */
	thrust::host_vector<bool> h_p = d_p;

#if PRINT
	thrust::host_vector<dataTypeI> h_A = d_A;
	thrust::host_vector<dataTypeI> h_B = d_B;
	thrust::host_vector<dataTypeO> h_C = d_C;

	printf("\n"); // Formatting stdout

	for ( int a = 0; a < k; a++ ) {
		for ( int b = 0; b < n; b++ )
		printf( "{%0.1f %0.1f} ", __half2float(h_A[a * k + b].x), __half2float(h_A[a * k + b].y) );
		printf("\n");
	}
	printf("\n");

	for ( int a = 0; a < m; a++ ) {
		for ( int b = 0; b < k; b++ )
		printf( "{%0.1f %0.1f} ", __half2float(h_B[a * m + b].x), __half2float(h_A[a * m + b].y) );
		printf("\n");
	}
	printf("\n");

	for ( int a = 0; a < m; a++ ) {
		for ( int b = 0; b < n; b++ )
#if SCENARIO == 0
		printf( "{%0.1f %0.1f} ", __half2float(h_C[a * m + b].x), __half2float(h_C[a * m + b].y) );
#elif SCENARIO == 1
		printf( "{%0.1f, %0.1f} ", h_C[a * m + b].real(), h_C[a * m + b].imag() );
#endif
		printf( "\n" );
	}
	printf( "\n" );
#endif

	if ( h_p[0] ) printf("Passed Identity Test\n");
	else printf("\n");
#endif

	/* Destroy workspace */
	checkCudaErrors( cudaFree (d_workspace) );

	/* Shutdown */
	checkCudaErrors( cublasLtDestroy( handle ) );
}

/* Main */
int main( int argc, char **argv ) {

	int dev = findCudaDevice( argc, ( const char ** ) argv );

	if ( dev == -1 ) throw std::runtime_error( "!!!! CUDA device not found" );

	printf( "Run Square M N K A_Type B_Type C_Type Compute_Type Time GFLOPS\n" );

	int count = 0;
	int square = 1;

#if IDENTITY
	// Identity for square matrices
#if PRINT
	for ( int m = 16; m <= 16; m *= 2 )
#else
	for ( int m = 16; m <= 8192; m *= 2 )
#endif
		calculate( m, m, m, count, square );

	printf( "\n" ); // For better readability stdout

#else

	// Compute matrices
	for ( int m = 512; m <= 8192; m *= 2 )
	for ( int k = 1024; k <= 4096; k *= 2 )
	calculate( m, m, k, count, square );

	printf("\n");// For better readability stdout

	count = 0;
	square = 0;

	// Compute non-square matrices
	for ( int m = 4096; m <= 32768; m *= 2 )
	for ( int n = 512; n <= 8192; n *= 2 )
	for ( int k = 8; k <= 128; k *= 2 )
	calculate( m, n, k, count, square );

	printf("\n");// For better readability stdout

#endif

	return ( EXIT_SUCCESS );
}

Hi,

Thanks for the example. I was asking for an example because when I try to use it with “row order” matrices, it fails.

I’ve modified your example by adding the lines:
cublasLtOrder_t rowOrder = CUBLASLT_ORDER_ROW;
status = cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &rowOrder, sizeof(rowOrder));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &rowOrder, sizeof(rowOrder));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &rowOrder, sizeof(rowOrder));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
Am I doing it incorrectly?

Hi Serge, I don’t believe row order works with complex matrices, since I tried what you were doing and it failed. It would be good to get NVIDIA to comment, though.

Serge,

Your code works fine with my first example at the top of the thread. The program is failing because you’re comparing a column-major output (the host baseline) against a row-major output (cuBLASLt).

Rewrite the host reference for row-major format.

See: c++ - matrix multiplication as column major - Stack Overflow
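For illustration, a row-major host reference could look like this (a sketch under the same square N x N assumptions as simple_sgemm; the function name is hypothetical):

/* Host implementation of a simple version of sgemm, row-major */
static void simple_sgemm_rowmajor(int n, float alpha, const float *A,
                                  const float *B, float beta, float *C) {
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < n; ++j) {
      float prod = 0;

      for (int k = 0; k < n; ++k) {
        prod += A[i * n + k] * B[k * n + j]; /* row-major indexing */
      }

      C[i * n + j] = alpha * prod + beta * C[i * n + j];
    }
  }
}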

mnicely, can you confirm whether that layout works for complex or not?

Cliff,

I can confirm it works with square matrices. In my second example, set IDENTITY 1. I’m having issues with non-square matrices, though. Let me get back to you once I find the problem.

Yes,

my issue is not with the final test but with the call to “cublasLtMatmulAlgoGetHeuristic” which returns the error status CUBLAS_STATUS_NOT_SUPPORTED.

As I was saying, I only told cublas that all my matrices were using “row order” layout.

Do you know if cuBLASLt really supports matrices in row-major order? This is my main motivation for using cuBLASLt.

Serge,

cuBLASLt does support row-major format matrices; the explanation I provided earlier shows it working. The error CUBLAS_STATUS_NOT_SUPPORTED usually means the only available algorithms, determined by heuristics, are not compatible with your hardware. As an example, IGEMM with Tensor Cores only works on compute capability 7.2 and greater; you will get CUBLAS_STATUS_NOT_SUPPORTED if you try to run it on a 7.0 device.
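If you want to rule out a capability mismatch first, a small sketch using the CUDA runtime API (not cuBLASLt itself; it assumes the dev index from findCudaDevice, and the 7.2 threshold is just the IGEMM example above):

cudaDeviceProp prop;
checkCudaErrors( cudaGetDeviceProperties( &prop, dev ) );
printf( "Compute capability %d.%d\n", prop.major, prop.minor );
if ( prop.major * 10 + prop.minor < 72 ) {
    // e.g., int8 Tensor Core (IGEMM) paths are unavailable on this device
}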

Further, cublasLtMatmulAlgoGetHeuristic is not guaranteed to return an algorithm. Depending on the problem parameters, it can return 0 algorithms, in which case you will receive CUBLAS_STATUS_INVALID_VALUE.

Can you provide the code that is failing to run?

Sure, it’s the above example with a few printfs added and the row-order modification described before:

it prints “cublasLtMatmulAlgoGetHeuristic returns status 15”

#include <iostream>
/* Includes, system */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Includes, cuda */
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "helper_cuda.h"

/* Matrix size */
//#define N (4096)
#define N (100)

cublasStatus_t LtSgemm(cublasLtHandle_t ltHandle,
                       cublasOperation_t transa,
                       cublasOperation_t transb,
                       int m,
                       int n,
                       int k,
                       const float *alpha, /* host pointer */
                       const float *A,
                       int lda,
                       const float *B,
                       int ldb,
                       const float *beta, /* host pointer */
                       float *C,
                       int ldc,
                       void *workspace,
                       size_t workspaceSize)
{
    //=================================================
    // NB: in this specific program: m == n == k
    //=================================================
    
    cublasStatus_t status = CUBLAS_STATUS_SUCCESS;

    cublasLtMatmulDesc_t operationDesc = NULL;
    cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
    cublasLtMatmulPreference_t preference = NULL;

    int returnedResults = 0;
    cublasLtMatmulHeuristicResult_t heuristicResult = {};

    // Create operation descriptor; see cublasLtMatmulDescAttributes_t
    // for details about defaults; here we just set the transforms for
    // A and B.
    status = cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32F);
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        goto CLEANUP;
    }
    status = cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        goto CLEANUP;
    }
    status = cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb));
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        goto CLEANUP;
    }

    // Create matrix descriptors. Not setting any extra attributes.
    status = cublasLtMatrixLayoutCreate(
        &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        goto CLEANUP;
    }
    status = cublasLtMatrixLayoutCreate(
        &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        goto CLEANUP;
    }
    status = cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32F, m, n, ldc);
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        goto CLEANUP;
    }

    //=================================================
    // NB: in this specific program: m == n == k
    //=================================================
    cublasLtOrder_t rowOrder = CUBLASLT_ORDER_ROW;
    status = cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &rowOrder, sizeof(rowOrder));
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        goto CLEANUP;
    }
    status = cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &rowOrder, sizeof(rowOrder));
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        goto CLEANUP;
    }
    status = cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &rowOrder, sizeof(rowOrder));
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        goto CLEANUP;
    }


    // Create preference handle; In general, extra attributes can be
    // used here to disable tensor ops or to make sure algo selected
    // will work with badly aligned A, B, C. However, for simplicity
    // here we assume A,B,C are always well aligned (e.g., directly
    // come from cudaMalloc)
    status = cublasLtMatmulPreferenceCreate(&preference);
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        goto CLEANUP;
    }
    status = cublasLtMatmulPreferenceSetAttribute(
        preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        goto CLEANUP;
    }

    // We just need the best available heuristic to try and run matmul.
    // There is no guarantee that this will work. For example, if A is
    // badly aligned, you can request more (e.g. 32) algos and try to
    // run them one by one until something works.
    status = cublasLtMatmulAlgoGetHeuristic(
        ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, preference, 1, &heuristicResult, &returnedResults);
    if (status != CUBLAS_STATUS_SUCCESS)
    {
        fprintf(stderr, "cublasLtMatmulAlgoGetHeuristic returned status %d\n", status);
        goto CLEANUP;
    }

    if (returnedResults == 0)
    {
        fprintf(stderr, "cublasLtMatmulAlgoGetHeuristic found 0 algo\n");
        status = CUBLAS_STATUS_NOT_SUPPORTED;
        goto CLEANUP;
    }

    status = cublasLtMatmul(ltHandle,
                            operationDesc,
                            alpha,
                            A,
                            Adesc,
                            B,
                            Bdesc,
                            beta,
                            C,
                            Cdesc,
                            C,
                            Cdesc,
                            &heuristicResult.algo,
                            workspace,
                            workspaceSize,
                            0);

    CLEANUP:
    // Descriptors are no longer needed as all GPU work was already
    // enqueued.
    if (preference)
    {
        cublasLtMatmulPreferenceDestroy(preference);
    }
    if (Cdesc)
    {
        cublasLtMatrixLayoutDestroy(Cdesc);
    }
    if (Bdesc)
    {
        cublasLtMatrixLayoutDestroy(Bdesc);
    }
    if (Adesc)
    {
        cublasLtMatrixLayoutDestroy(Adesc);
    }
    if (operationDesc)
    {
        cublasLtMatmulDescDestroy(operationDesc);
    }
    return status;
}

/* Host implementation of a simple version of sgemm */
static void simple_sgemm(int n, float alpha, const float *A, const float *B,
                         float beta, float *C)
{
    int i;
    int j;
    int k;

    for (i = 0; i < n; ++i)
    {
        for (j = 0; j < n; ++j)
        {
            float prod = 0;

            for (k = 0; k < n; ++k)
            {
                prod += A[k * n + i] * B[j * n + k];
            }

            C[j * n + i] = alpha * prod + beta * C[j * n + i];
        }
    }
}

/* Main */
int main(int argc, char **argv)
{
    cublasStatus_t status;
    float *h_A;
    float *h_B;
    float *h_C;
    float *h_C_ref;
    float *d_A = 0;
    float *d_B = 0;
    float *d_C = 0;
    float alpha = 1.0f;
    float beta = 0.0f;
    int n2 = N * N;
    int i;
    float error_norm;
    float ref_norm;
    float diff;
    cublasLtHandle_t handle;

    int dev = findCudaDevice(argc, (const char **) argv);

    if (dev == -1)
    {
        return EXIT_FAILURE;
    }

    /* Initialize CUBLAS */
    printf("simpleCUBLAS test running..\n");

    status = cublasLtCreate(&handle);

    if (status != CUBLAS_STATUS_SUCCESS)
    {
        fprintf(stderr, "!!!! CUBLAS initialization error\n");
        return EXIT_FAILURE;
    }

    /* Allocate host memory for the matrices */
    h_A = reinterpret_cast<float *>(malloc(n2 * sizeof(h_A[0])));

    if (h_A == 0)
    {
        fprintf(stderr, "!!!! host memory allocation error (A)\n");
        return EXIT_FAILURE;
    }

    h_B = reinterpret_cast<float *>(malloc(n2 * sizeof(h_B[0])));

    if (h_B == 0)
    {
        fprintf(stderr, "!!!! host memory allocation error (B)\n");
        return EXIT_FAILURE;
    }

    h_C = reinterpret_cast<float *>(malloc(n2 * sizeof(h_C[0])));

    if (h_C == 0)
    {
        fprintf(stderr, "!!!! host memory allocation error (C)\n");
        return EXIT_FAILURE;
    }

    /* Fill the matrices with test data */
    for (i = 0; i < n2; i++)
    {
        h_A[i] = rand() / static_cast<float>(RAND_MAX);
        h_B[i] = rand() / static_cast<float>(RAND_MAX);
        h_C[i] = rand() / static_cast<float>(RAND_MAX);
    }

    /* Allocate device memory for the matrices */
    if (cudaMalloc(reinterpret_cast<void **>(&d_A), n2 * sizeof(d_A[0])) !=
        cudaSuccess)
    {
        fprintf(stderr, "!!!! device memory allocation error (allocate A)\n");
        return EXIT_FAILURE;
    }

    if (cudaMalloc(reinterpret_cast<void **>(&d_B), n2 * sizeof(d_B[0])) !=
        cudaSuccess)
    {
        fprintf(stderr, "!!!! device memory allocation error (allocate B)\n");
        return EXIT_FAILURE;
    }

    if (cudaMalloc(reinterpret_cast<void **>(&d_C), n2 * sizeof(d_C[0])) !=
        cudaSuccess)
    {
        fprintf(stderr, "!!!! device memory allocation error (allocate C)\n");
        return EXIT_FAILURE;
    }

    /* Initialize the device matrices with the host matrices */
    status = cublasSetVector(n2, sizeof(h_A[0]), h_A, 1, d_A, 1);

    if (status != CUBLAS_STATUS_SUCCESS)
    {
        fprintf(stderr, "!!!! device access error (write A)\n");
        return EXIT_FAILURE;
    }

    status = cublasSetVector(n2, sizeof(h_B[0]), h_B, 1, d_B, 1);

    if (status != CUBLAS_STATUS_SUCCESS)
    {
        fprintf(stderr, "!!!! device access error (write B)\n");
        return EXIT_FAILURE;
    }

    status = cublasSetVector(n2, sizeof(h_C[0]), h_C, 1, d_C, 1);

    if (status != CUBLAS_STATUS_SUCCESS)
    {
        fprintf(stderr, "!!!! device access error (write C)\n");
        return EXIT_FAILURE;
    }

    /* Performs operation using plain C code */
    simple_sgemm(N, alpha, h_A, h_B, beta, h_C);
    h_C_ref = h_C;

    /* Original cublasSgemm call, replaced by LtSgemm below */
    //  status = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, N, N, &alpha, d_A,
    //                       N, d_B, N, &beta, d_C, N);

    status = LtSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, N, N, &alpha, d_A,
                     N, d_B, N, &beta, d_C, N, nullptr, 0);

    if (status != CUBLAS_STATUS_SUCCESS)
    {
        fprintf(stderr, "!!!! kernel execution error.\n");
        return EXIT_FAILURE;
    }

    /* Allocate host memory for reading back the result from device memory */
    h_C = reinterpret_cast<float *>(malloc(n2 * sizeof(h_C[0])));

    if (h_C == 0)
    {
        fprintf(stderr, "!!!! host memory allocation error (C)\n");
        return EXIT_FAILURE;
    }

    /* Read the result back */
    status = cublasGetVector(n2, sizeof(h_C[0]), d_C, 1, h_C, 1);

    if (status != CUBLAS_STATUS_SUCCESS)
    {
        fprintf(stderr, "!!!! device access error (read C)\n");
        return EXIT_FAILURE;
    }

    /* Check result against reference */
    error_norm = 0;
    ref_norm = 0;

    for (i = 0; i < n2; ++i)
    {
        diff = h_C_ref[i] - h_C[i];
        error_norm += diff * diff;
        ref_norm += h_C_ref[i] * h_C_ref[i];
    }

    error_norm = static_cast<float>(sqrt(static_cast<double>(error_norm)));
    ref_norm = static_cast<float>(sqrt(static_cast<double>(ref_norm)));

    if (fabs(ref_norm) < 1e-7)
    {
        fprintf(stderr, "!!!! reference norm is 0\n");
        return EXIT_FAILURE;
    }

    /* Memory clean up */
    free(h_A);
    free(h_B);
    free(h_C);
    free(h_C_ref);

    if (cudaFree(d_A) != cudaSuccess)
    {
        fprintf(stderr, "!!!! memory free error (A)\n");
        return EXIT_FAILURE;
    }

    if (cudaFree(d_B) != cudaSuccess)
    {
        fprintf(stderr, "!!!! memory free error (B)\n");
        return EXIT_FAILURE;
    }

    if (cudaFree(d_C) != cudaSuccess)
    {
        fprintf(stderr, "!!!! memory free error (C)\n");
        return EXIT_FAILURE;
    }

    /* Shutdown */
    status = cublasLtDestroy(handle);

    if (status != CUBLAS_STATUS_SUCCESS)
    {
        fprintf(stderr, "!!!! shutdown error (A)\n");
        return EXIT_FAILURE;
    }

    if (error_norm / ref_norm < 1e-3f)
    {
        printf("simpleCUBLAS test passed.\n");
        exit(EXIT_SUCCESS);
    }
    else
    {
        printf("simpleCUBLAS test failed.\n");
        exit(EXIT_FAILURE);
    }
}

Serge,

I’m able to run the attached code on a Titan V (CC 7.0), although it fails the comparison test.

What GPU are you using? What compute capability are you compiling?

Also, I’ve provided a cleaner example at the link below. It doesn’t require cublas.h and has cleaner error checking.

https://github.com/mnicely/cublasLt_examples/tree/master/cublasLt_sgemm

Mnicely, does row-major order work for planar-format CGEMM? I get errors when trying it on a V100. I’m still trying to reproduce your results to get the correct answer. The identity matrix gives the correct results, but random data appears to give the wrong answer. I’m trying this all on an 8x8 for simplicity. Also, I want to avoid using the transform function since it has a fairly large performance hit in your code. It appears that all it does is list all the real parts linearly, followed by the imaginary parts, so I’ve done that step manually.

Still not sure where the error is.

I’ve run your program and I still have the same error:

CUDA error at C:/src/test_blaslt/test_blaslt/test_blaslt/kernel.cu:169 code=15(CUBLAS_STATUS_NOT_SUPPORTED) “cublasLtMatmulAlgoGetHeuristic( ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, preference, 1, &heuristicResult, &returnedResults)”

using the deviceQuery example, I get:

C:\ProgramData\NVIDIA Corporation\CUDA Samples\v10.1\bin\win64\Release\deviceQuery.exe Starting...

 CUDA Device Query (Runtime API) version (CUDART static linking)

Detected 1 CUDA Capable device(s)

Device 0: "GeForce GTX 1050"
  CUDA Driver Version / Runtime Version          10.1 / 10.1
  CUDA Capability Major/Minor version number:    6.1
  Total amount of global memory:                 4096 MBytes (4294967296 bytes)
  ( 5) Multiprocessors, (128) CUDA Cores/MP:     640 CUDA Cores
  GPU Max Clock rate:                            1493 MHz (1.49 GHz)
  Memory Clock rate:                             3504 Mhz
  Memory Bus Width:                              128-bit
  L2 Cache Size:                                 524288 bytes
  Maximum Texture Dimension Size (x,y,z)         1D=(131072), 2D=(131072, 65536), 3D=(16384, 16384, 16384)
  Maximum Layered 1D Texture Size, (num) layers  1D=(32768), 2048 layers
  Maximum Layered 2D Texture Size, (num) layers  2D=(32768, 32768), 2048 layers
  Total amount of constant memory:               65536 bytes
  Total amount of shared memory per block:       49152 bytes
  Total number of registers available per block: 65536
  Warp size:                                     32
  Maximum number of threads per multiprocessor:  2048
  Maximum number of threads per block:           1024
  Max dimension size of a thread block (x,y,z): (1024, 1024, 64)
  Max dimension size of a grid size    (x,y,z): (2147483647, 65535, 65535)
  Maximum memory pitch:                          2147483647 bytes
  Texture alignment:                             512 bytes
  Concurrent copy and kernel execution:          Yes with 5 copy engine(s)
  Run time limit on kernels:                     No
  Integrated GPU sharing Host Memory:            No
  Support host page-locked memory mapping:       Yes
  Alignment requirement for Surfaces:            Yes
  Device has ECC support:                        Disabled
  CUDA Device Driver Mode (TCC or WDDM):         WDDM (Windows Display Driver Model)
  Device supports Unified Addressing (UVA):      Yes
  Device supports Compute Preemption:            Yes
  Supports Cooperative Kernel Launch:            No
  Supports MultiDevice Co-op Kernel Launch:      No
  Device PCI Domain ID / Bus ID / location ID:   0 / 1 / 0
  Compute Mode:
     < Default (multiple host threads can use ::cudaSetDevice() with device simultaneously) >

deviceQuery, CUDA Driver = CUDART, CUDA Driver Version = 10.1, CUDA Runtime Version = 10.1, NumDevs = 1
Result = PASS

Sergey, how are you compiling it?

Serge,

I have confirmed your code works on a GTX 1080 (CC 6.1). There is no reason to believe it will not work on a GTX 1050. If you will humor me, check out https://github.com/mnicely/cublasLt_examples/tree/master/cublasLt_sgemm and, in the Release folder, run

make clean; make; ./cublasLt_sgemm

Confirm that it works with row-major and column-major.

Cliff,

I have confirmed that, on square matrices with complex half precision, running with row-major order gives the same answer as column-major order, which is incorrect. I will file a ticket and let you know when I hear back from the developers.