Hi Everyone,
I’ve returned to the forum after a long hiatus from CUDA.
Just getting back into things (so I'm quite inexperienced with the newer functionality), and I've noticed some strange behavior while writing a kernel.
Just for some background, I’m trying to create my own tridiagonal solver function (similar to cusparseSgtsv) with a few changes:
- I store my problem in row-major order instead of column-major.
- The number of columns in my matrix is a multiple of 32 (typically 32, 64, or 96, but sometimes 320, 640, etc.), and the number of rows is typically in the range 1e6-4e6.
- The tridiagonal matrix is symmetric positive definite (diagonally dominant), and there are many sporadic occurrences of zeros on the off-diagonals. These zeros let me partition the problem into many smaller independent subsets and run a basic Thomas algorithm on each, rather than using CR / PCR / partial pivoting (see the sketch after this list).
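For reference, each partition just runs the classic Thomas recurrence over its rows. Here's a minimal host-side sketch (thomas_reference is a throwaway name of mine, not the kernel code further down; the kernel runs the same recurrence with one thread per column):

// Reference Thomas algorithm for one tridiagonal partition of n rows.
// a = sub-diagonal (a[0] unused), b = diagonal, c = super-diagonal
// (c[n-1] unused), d = right-hand side, overwritten with the solution.
// Note: c is overwritten with the factorized super-diagonal.
void thomas_reference(const float *a, const float *b, float *c, float *d, int n)
{
    // forward sweep: eliminate the sub-diagonal
    c[0] /= b[0];
    d[0] /= b[0];
    for (int i = 1; i < n; i++)
    {
        float m = 1.0f / (b[i] - a[i] * c[i - 1]);
        c[i] *= m;
        d[i] = (d[i] - a[i] * d[i - 1]) * m;
    }
    // backward substitution
    for (int i = n - 2; i >= 0; i--)
        d[i] -= c[i] * d[i + 1];
}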
I’ve built an effective kernel for 32 columns, and I was looking to generalize it to 32*n columns. What I noticed while optimizing is that hard-coding the problem size as a template parameter is almost twice(!) as fast as passing it as a function parameter. Unfortunately I'm not a good enough coder to dig around under the hood, so I've been chasing my tail trying to figure out why this happens and how to fix it.
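The only workaround I've found so far is to keep the template kernel and dispatch on the runtime column count, since it's always a multiple of 32. Roughly like this (a sketch only; launch_tsolve is a hypothetical wrapper around the kernels declared in main.cuh below):

void launch_tsolve(dim3 grid, dim3 block, float *out, float *in,
    const float *T0, const float *T1, const float *T2,
    const int *blockid, int dimid, int numcol)
{
    // compile-time instantiations for the common sizes,
    // runtime-parameter kernel as the fallback
    switch (numcol)
    {
    case 32: Tsolve_kernel<32><<<grid, block>>>(out, in, T0, T1, T2, blockid, dimid); break;
    case 64: Tsolve_kernel<64><<<grid, block>>>(out, in, T0, T1, T2, blockid, dimid); break;
    case 96: Tsolve_kernel<96><<<grid, block>>>(out, in, T0, T1, T2, blockid, dimid); break;
    default: Tsolve_kernel1<<<grid, block>>>(out, in, T0, T1, T2, blockid, dimid, numcol); break;
    }
}

That keeps the fast path for the sizes I actually use, but I'd still like to understand why the runtime-parameter version is so much slower.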
Any help would be appreciated.
Attached output below, as well as sample source code to demonstrate the issue. There's a file-output section you can uncomment if you want to check the result against the MATLAB solution, but it's quite time-consuming.
Edit: Compiled with Visual Studio 2015 + CUDA 9.1, run on a Titan X + Win 7
//Output
Found 206897 zero off diagonals.
Running Kernel...
Custom Gtsv Time taken: 3.623689
Custom Gtsv non-template Time taken: 6.851085
Cusparse Gtsv Time taken: 45.359820
//main.cu
#include "main.cuh"
int main()
{
    int numel = 1e6;
    int numcol = 64;
    // initialize T
    thrust::host_vector<float> a_hst(numel);
    thrust::host_vector<float> b_hst(numel);
    thrust::host_vector<float> c_hst(numel);
    thrust::host_vector<float> data_hst(numel * numcol);
    a_hst[0] = 0;
    c_hst[numel - 1] = 0;
    for (int i = 1; i < numel; i++)
    {
        float rval = (float)((i + 1) % 29) / 58.0f;
        rval = rval > 0.1 ? rval : 0;
        a_hst[i] = -rval;
    }
    for (int i = 0; i < numel; i++) b_hst[i] = 1.0f + (float)(i % 31) / 31.0f;
    for (int i = 0; i < numel - 1; i++) c_hst[i] = a_hst[i + 1];
    thrust::device_vector<float> c_dvc_1 = c_hst; // unfactorized copy for cusparseSgtsv
    // pre-factorize the super-diagonal on the host so the custom kernels can treat it as read-only
    c_hst[0] = c_hst[0] / b_hst[0];
    for (int i = 1; i < numel - 1; i++) c_hst[i] = c_hst[i] / (b_hst[i] - a_hst[i] * c_hst[i - 1]);
    for (int i = 0; i < numel * numcol; i++) data_hst[i] = (i + 1) % 37; // or ((double)rand() / RAND_MAX)
    thrust::device_vector<float> a_dvc = a_hst;
    thrust::device_vector<float> b_dvc = b_hst;
    thrust::device_vector<float> c_dvc = c_hst;
    thrust::device_vector<float> data_dvc = data_hst;
    thrust::device_vector<float> output_dvc(numcol * numel);
    thrust::device_vector<float> output_dvc_transpose(numcol * numel);
    thrust::device_vector<int> indices(numel + 1);
    // counting iterators define a sequence [0, size)
    thrust::counting_iterator<int> first(0);
    thrust::counting_iterator<int> last = first + numel;
    // compute indices of the zero off-diagonal elements (the partition boundaries)
    typedef thrust::device_vector<int>::iterator IndexIterator;
    IndexIterator indices_end = thrust::remove_copy_if(first, last, a_dvc.begin(), indices.begin(), thrust::identity<float>());
    std::cout << "Found " << (indices_end - indices.begin()) << " zero off diagonals.\nRunning Kernel...\n";
    int numproblems = indices_end - indices.begin();
    indices[numproblems] = numel;
    thrust::device_vector<int> indices1(indices.begin(), indices_end + 1);
    dim3 threadsPerBlock(32, BLOCKSIZE / 32);
    int numblocks = (BLOCKS_PER_SM * NUM_SM);
    cusparseHandle_t csh;
    cusparseCreate(&csh);
    cublasHandle_t cbh;
    cublasCreate(&cbh);
    Timer t;
    t.start();
    for (int i = 0; i < 1000; i++)
    {
        Tsolve_kernel<64><<<numblocks, threadsPerBlock>>>(thrust::raw_pointer_cast(&output_dvc[0]), thrust::raw_pointer_cast(&data_dvc[0]), thrust::raw_pointer_cast(&a_dvc[0]), thrust::raw_pointer_cast(&b_dvc[0]), thrust::raw_pointer_cast(&c_dvc[0]), thrust::raw_pointer_cast(&indices[0]), numproblems);
    }
    cudaDeviceSynchronize();
    printf("Custom Gtsv Time taken: %6f\n", t.stop());
    t.start();
    for (int i = 0; i < 1000; i++)
    {
        Tsolve_kernel1<<<numblocks, threadsPerBlock>>>(thrust::raw_pointer_cast(&output_dvc[0]), thrust::raw_pointer_cast(&data_dvc[0]), thrust::raw_pointer_cast(&a_dvc[0]), thrust::raw_pointer_cast(&b_dvc[0]), thrust::raw_pointer_cast(&c_dvc[0]), thrust::raw_pointer_cast(&indices[0]), numproblems, numcol);
    }
    cudaDeviceSynchronize();
    printf("Custom Gtsv non-template Time taken: %6f\n", t.stop());
    thrust::host_vector<float> out = output_dvc;
    t.start();
    for (int i = 0; i < 1000; i++)
    {
        float alpha = 1;
        float beta = 0;
        // transpose to column-major, solve, then transpose back to row-major
        cublasSgeam(cbh, CUBLAS_OP_T, CUBLAS_OP_T, numel, numcol, &alpha, thrust::raw_pointer_cast(&data_dvc[0]), numcol, &beta, NULL, numcol, thrust::raw_pointer_cast(&output_dvc_transpose[0]), numel);
        cusparseSgtsv(csh, numel, numcol, thrust::raw_pointer_cast(&a_dvc[0]), thrust::raw_pointer_cast(&b_dvc[0]), thrust::raw_pointer_cast(&c_dvc_1[0]), thrust::raw_pointer_cast(&output_dvc_transpose[0]), numel);
        cublasSgeam(cbh, CUBLAS_OP_T, CUBLAS_OP_T, numcol, numel, &alpha, thrust::raw_pointer_cast(&output_dvc_transpose[0]), numel, &beta, NULL, numel, thrust::raw_pointer_cast(&output_dvc[0]), numcol);
    }
    cudaDeviceSynchronize();
printf("Cusparse Gtsv Time taken: %6f", t.stop());
    //std::ofstream myfile;
    //myfile.open("c:\\work\\doutput.txt");
    //for (int i = 0; i < (int)out.size(); i++)
    //{
    //    myfile << out[i] << "\n";
    //}
    //myfile.close();
    cudaDeviceSynchronize();
    cublasDestroy(cbh);
    cusparseDestroy(csh);
    std::cin.get();
}
//main.cuh
#ifndef MAIN
#define MAIN
#include "device_launch_parameters.h"
#include <random>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/copy.h>
#include <thrust/functional.h>
#include <thrust/device_vector.h>
#include <iterator>
#include <iostream>
#include <fstream>
#include <cusparse_v2.h>
#include <cublas_v2.h>
#define NUM_SM 100
#define BLOCKS_PER_SM 6
#define BLOCKSIZE 256
#include <chrono>
class Timer
{
public:
    Timer() : beg_(clock_::now()) {}
    void start() { beg_ = clock_::now(); }
    double stop() const {
        return std::chrono::duration_cast<second_>(clock_::now() - beg_).count();
    }
private:
    typedef std::chrono::high_resolution_clock clock_;
    typedef std::chrono::duration<double, std::ratio<1> > second_;
    std::chrono::time_point<clock_> beg_;
};
#ifndef CUDA
#define CUDA() do { \
    cudaError_t _e = cudaGetLastError(); \
    if (_e == cudaSuccess) break; \
    char errstr[128]; \
    _snprintf(errstr, 128, \
        "%s(%d) CUDA Error(%d)\n", \
        __FILE__, __LINE__, _e); \
    throw std::runtime_error(errstr); \
} while (0)
#endif
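// Usage: place CUDA() right after a kernel launch to turn launch errors into exceptions.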
__global__ void Tsolve_kernel1(float * output, float * input, const float * __restrict__ T0, const float * __restrict__ T1, const float * __restrict__ T2, const int * __restrict__ blockid, int dimid, int numcol);
template <int numcol>
__launch_bounds__(BLOCKSIZE, BLOCKS_PER_SM)
__global__ void Tsolve_kernel(float * output, const float * __restrict__ input, const float * __restrict__ T0, const float * __restrict__ T1, const float * __restrict__ T2, const int * __restrict__ blockid, int dimid)
{
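    // Thread mapping: threadIdx.x picks one of the 32 columns in the current
    // 32-wide tile; (threadIdx.y, blockIdx.x) pick which partition this group
    // of 32 threads works on, grid-striding until all partitions are solved.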
    const int linenumber = threadIdx.y + blockIdx.x * (BLOCKSIZE / 32);
    const int lineskip = gridDim.x * (BLOCKSIZE / 32);
    input += threadIdx.x;
    output += threadIdx.x;
    blockid += linenumber;
    const float *T0_base = T0;
    const float *T1_base = T1;
    const float *T2_base = T2;
    const float *input_base = input;
    float *output_base = output;
    for (int i = linenumber; i < dimid; i += lineskip)
    {
        int rowoffset = blockid[0];
        int numel = blockid[1] - rowoffset;
        for (int a = 0; a < numcol; a += 32)
        {
            T0 = T0_base + rowoffset;
            T1 = T1_base + rowoffset;
            T2 = T2_base + rowoffset;
            input = input_base + a + numcol * rowoffset;
            output = output_base + a + numcol * rowoffset;
            // pass forward
            float output_reg = input[0] / T1[0];
            output[0] = output_reg;
            for (int j = 1; j < numel; j++)
            {
                input += numcol;
                output += numcol;
                T0 += 1;
                T1 += 1;
                output_reg = (input[0] - T0[0] * output_reg) / (T1[0] - T0[0] * T2[0]);
                T2 += 1;
                output[0] = output_reg;
            }
            // pass backward
            for (int j = 1; j < numel; j++)
            {
                output -= numcol;
                T2 -= 1;
                output_reg = output[0] - T2[0] * output_reg;
                output[0] = output_reg;
            }
        }
        blockid += lineskip;
    }
}
#endif
//kernels.cu
#include "main.cuh"
__launch_bounds__(BLOCKSIZE, BLOCKS_PER_SM)
__global__ void Tsolve_kernel1(float * output, float * input, const float * __restrict__ T0, const float * __restrict__ T1, const float * __restrict__ T2, const int * __restrict__ blockid, int dimid, int numcol)
{
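    // Runtime-numcol variant of Tsolve_kernel (main.cuh). Apart from numcol
    // being a function parameter (and input not being const __restrict__ here),
    // the body is identical.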
    const int linenumber = threadIdx.y + blockIdx.x * (BLOCKSIZE / 32);
    const int lineskip = gridDim.x * (BLOCKSIZE / 32);
    input += threadIdx.x;
    output += threadIdx.x;
    blockid += linenumber;
    const float *T0_base = T0;
    const float *T1_base = T1;
    const float *T2_base = T2;
    float *input_base = input;
    float *output_base = output;
    for (int i = linenumber; i < dimid; i += lineskip)
    {
        int rowoffset = blockid[0];
        int numel = blockid[1] - rowoffset;
        for (int a = 0; a < numcol; a += 32)
        {
            T0 = T0_base + rowoffset;
            T1 = T1_base + rowoffset;
            T2 = T2_base + rowoffset;
            input = input_base + a + numcol * rowoffset;
            output = output_base + a + numcol * rowoffset;
            // pass forward
            float output_reg = input[0] / T1[0];
            output[0] = output_reg;
            for (int j = 1; j < numel; j++)
            {
                input += numcol;
                output += numcol;
                T0 += 1;
                T1 += 1;
                output_reg = (input[0] - T0[0] * output_reg) / (T1[0] - T0[0] * T2[0]);
                T2 += 1;
                output[0] = output_reg;
            }
            // pass backward
            for (int j = 1; j < numel; j++)
            {
                output -= numcol;
                T2 -= 1;
                output_reg = output[0] - T2[0] * output_reg;
                output[0] = output_reg;
            }
        }
        blockid += lineskip;
    }
}
//Tested against a MATLAB tridiagonal solve with:
dims=1e6;
nc=64;
D0=[0;mod(2:dims,29)'./58];
D1=1+mod((0:dims-1)',31)./31;
D2=circshift(D0,-1);
D0(D0<0.1)=0;
D2(D2<0.1)=0;
T=spdiags([-D2 D1 -D0],-1:1,dims,dims);
tosolve=zeros(nc,dims);tosolve(:)=mod((1:nc*dims),37);tosolve=tosolve';
out=(T\tosolve)';
co=csvread('c:\work\doutput.txt');
corr(co(:),out(:))