Non-deterministic 8-bit convolution output when channels are increased beyond 4 with cudnnConvolutionForward()

Hello All,

I have written a GPU op in Theano that calls cudnnConvolutionForward() to perform 8-bit convolution. Since CudaNdarray only supports fp32 for now, I have to pass the input and filter to Theano as float32 arguments and typecast them to INT on the device before handing them to cudnnConvolutionForward().

I do this using kernels called "Float2Int" and "DeviceFloatCopy" (most probably I am allocating the threads wrong somewhere… :( ).

One more kernel, "Device2DeviceCopy", copies the contents of pwdesc (a pointer to the GPU memory holding the filter).

Since I am not very good at thread programming, I am sure something is going wrong here. :(
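
For context, the launch configuration I am aiming for in these cast/copy kernels is the usual ceil-division pattern. Below is a minimal sketch (the names are illustrative, not the attached kernels; it assumes count is the number of elements to convert, which is also what the kernel's bounds check compares against):

__global__ void Float2IntSketch(const float* in, int* out, int count)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < count)              // guard the last, partially filled block
        out[i] = (int)in[i];    // plain truncation, as in my Float2Int kernel
}

void launchFloat2IntSketch(const float* in, int* out, int count)
{
    const int threadsPerBlock = 1024;
    const int blocks = (count + threadsPerBlock - 1) / threadsPerBlock;  // ceil division
    Float2IntSketch<<<blocks, threadsPerBlock>>>(in, out, count);
}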

In convolution8Bit.cu, when I set the following piece of code inside the function

"void APPLY_SPECIFIC(doConv)"

size_t channels = 4;
int out_channels = 4;

I always get the correct output. I feed the input and filter from the conv.py file, compute the output of my Theano GPU op and of the normal scipy.signal.convolve2d at the end, and both should give the same answer.

BUT, when I change the above two lines of code to

size_t channels = 64;
int out_channels = 64;

the output of the convolution is sometimes correct and sometimes not. I think this non-deterministic behaviour comes from the threads.
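
While chasing this I am also trying to rule out launch failures by checking every kernel launch for errors, roughly like the sketch below (checkLastLaunch is a hypothetical helper, not part of the attached files), and by running the test under cuda-memcheck:

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical per-launch check, used only while debugging.
static void checkLastLaunch(const char* where)
{
    // Catch launch/configuration errors (bad grid dimensions, etc.).
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        printf("%s: launch error: %s\n", where, cudaGetErrorString(err));

    // Wait for the kernel so execution errors (e.g. out-of-bounds accesses
    // flagged by cuda-memcheck) are reported at this point.
    err = cudaDeviceSynchronize();
    if (err != cudaSuccess)
        printf("%s: execution error: %s\n", where, cudaGetErrorString(err));
}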

Since I am not that good with threads, could someone review my code and guide me forward?

I attach both of my files here:

conv.py
convolution8Bit.cu

Requirements:

  1. GPU with INT8 support (I am using an NVIDIA GTX 1080 Ti)
  2. cuDNN 6.0
  3. CUDA 8.0
  4. Theano
  5. Keras

Steps to run:

  1. Copy both files from below.
  2. In the convolution8Bit.cu file, change

#include "/home/d1230/cudnn6.0/cuda/include/cudnn2.h"

to the path of your cudnn.h (cuDNN 6.0).

  3. Run the command:
    python conv.py

File conv.py

############################################################################

import numpy as np
from scipy import signal
from theano import gof
#import theano
from theano.gof import COp, Op
import theano.tensor as T
import theano.scalar as S
from theano import function, config
import os
import theano.sandbox.cuda as cuda
from theano.sandbox.cuda.basic_ops import as_cuda_ndarray_variable, GpuOp, gpu_contiguous
from theano import Apply
from theano.sandbox.cuda.type import CudaNdarrayType
from theano import shared
from theano.gradient import grad_undefined

class convGpuOP(COp):
    __props__ = ()

    def __init__(self, dtype=config.floatX):
        self.func_file = "./convolution8Bit.cu"
        #self.func_file = "./convolutionfp32Bit.cu"
        self.func_name = "APPLY_SPECIFIC(doConvolution8Bit)"
        self.dtype = dtype
        super(convGpuOP, self).__init__(self.func_file,
                                        self.func_name)

    #def c_headers(self):
    #    return ['<cublas_v2.h>', '<stdio.h>']
    def c_headers(self):
        return ['/home/d1230/cudnn6.0/cuda/include/cudnn2.h']

    def c_header_dirs(self):
        return ['/home/d1230/cudnn6.0/cuda/include']

    def c_lib_dirs(self):
        return ['/home/d1230/cudnn6.0/cuda/lib64']

    def c_libraries(self):
        return ['cudnn', 'cudart']

    def c_compile_args(self):
        return ['-DFIX', '-std=c++11']

    def make_node(self, x, y):
        # Validate the inputs' type
        x = as_cuda_ndarray_variable(x)
        assert x.dtype == "float32"
        y = as_cuda_ndarray_variable(y)
        assert y.dtype == "float32"

        assert x.ndim == 2
        assert y.ndim == 2

        return Apply(self, [x, y], [CudaNdarrayType(broadcastable=(False, False))()])

doConv8Bit = convGpuOP()

if __name__ == '__main__':

    a = T.matrix(dtype='float32')
    b = T.matrix(dtype='float32')

    c = doConv8Bit(a, b)

    fun = function([a, b], c)
    f = ((np.arange(224*224).reshape(224, 224)) % 128).astype('float32')
    g = (np.ones(9).reshape(3, 3)).astype('float32')
    #print f
    #print g
    #print f.shape, g.shape
    print "My matrix convolution"

    t = fun(f, g)
    print t
    print "Normal matrix convolution"
    f_normal = signal.convolve2d(f, g, 'valid')

    print np.array_equal(t, f_normal)

#####################################################################################

File convolution8Bit.cu

#######################################################################################

#section support_code
#include <cublas_v2.h>
#include <math.h>
#include <cuda.h>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <sstream>
#include <iostream>
#include <stdexcept>
#include <vector>
#include <algorithm>
#include <chrono>
#include "/home/d1230/cudnn6.0/cuda/include/cudnn2.h"
//#include "cuda_runtime.h"
#include "device_launch_parameters.h"

#define MAX_MSG_LEN 1000

/** Error handling from https://developer.nvidia.com/cuDNN */
#define FatalError(s)                                                    \
do {                                                                     \
    std::stringstream _where, _message;                                  \
    _where << __FILE__ << ':' << __LINE__;                               \
    _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__;    \
    std::cerr << _message.str() << "\nAborting...\n";                    \
    cudaDeviceReset();                                                   \
    exit(1);                                                             \
} while (0)

#define checkCUDNN(status)                                               \
do {                                                                     \
    std::stringstream _error;                                            \
    if (status != CUDNN_STATUS_SUCCESS) {                                \
        _error << "CUDNN failure: " << cudnnGetErrorString(status);      \
        FatalError(_error.str());                                        \
    }                                                                    \
} while (0)

#define checkCudaErrors(status)                                          \
do {                                                                     \
    std::stringstream _error;                                            \
    if (status != 0) {                                                   \
        _error << "Cuda failure: " << status;                            \
        FatalError(_error.str());                                        \
    }                                                                    \
} while (0)

/** Convolutional layer */
struct ConvolutionLayer {
int kernel_size;
int in_channels, in_height, in_width;
int out_channels, out_height, out_width;
std::vector<float> pconv;

ConvolutionLayer(int in_channels_,
int out_channels_,
int kernel_size_,
int in_w_,
int in_h_)
: pconv(in_channels_ * kernel_size_ * kernel_size_ * out_channels_) {
in_channels = in_channels_;
out_channels = out_channels_;
kernel_size = kernel_size_;
in_width = in_w_;
in_height = in_h_;
out_width = in_w_ - kernel_size_ + 1;
out_height = in_h_ - kernel_size_ + 1;
}
};

/** Training context */
struct TrainingContext {
cudnnHandle_t cudnnHandle;
cudnnTensorDescriptor_t dataTensor, conv1Tensor, conv1BiasTensor;
cudnnFilterDescriptor_t conv1filterDesc;
cudnnConvolutionDescriptor_t conv1Desc;
cudnnConvolutionFwdAlgo_t conv1algo;
int m_gpuid;
int m_batchSize;
size_t m_workspaceSize;

// Disable copying
TrainingContext& operator=(const TrainingContext&) = delete;
TrainingContext(const TrainingContext&) = delete;

// Constructor
TrainingContext(int gpuid, int batch_size, ConvolutionLayer& conv1)
: m_gpuid(gpuid) {
m_batchSize = batch_size;

/** Create descriptors within the constructor.
  * As instructed in the user manual, descriptors for
  * input and output tensors, filter, and the forward
  * convolution operator are created along with the
  * cuDNN handle.
  */
checkCudaErrors(cudaSetDevice(gpuid));
checkCUDNN(cudnnCreate(&cudnnHandle));
checkCUDNN(cudnnCreateTensorDescriptor(&dataTensor));
checkCUDNN(cudnnCreateFilterDescriptor(&conv1filterDesc));
checkCUDNN(cudnnCreateConvolutionDescriptor(&conv1Desc));
checkCUDNN(cudnnCreateTensorDescriptor(&conv1Tensor));

 // Initialize convolution forward pass
size_t workspaceSizeFromConv = SetFwdConvolutionTensors(
    conv1, dataTensor, conv1Tensor, conv1filterDesc, conv1Desc, conv1algo);
m_workspaceSize = std::max((int)workspaceSizeFromConv, 0);

}

~TrainingContext() {
checkCudaErrors(cudaSetDevice(m_gpuid));
checkCUDNN(cudnnDestroy(cudnnHandle));
checkCUDNN(cudnnDestroyTensorDescriptor(dataTensor));
checkCUDNN(cudnnDestroyTensorDescriptor(conv1Tensor));
checkCUDNN(cudnnDestroyFilterDescriptor(conv1filterDesc));
checkCUDNN(cudnnDestroyConvolutionDescriptor(conv1Desc));
}

/** Set tensors and ops for forward pass */
size_t SetFwdConvolutionTensors(ConvolutionLayer& conv,
cudnnTensorDescriptor_t& srcTensorDesc,
cudnnTensorDescriptor_t& dstTensorDesc,
cudnnFilterDescriptor_t& filterDesc,
cudnnConvolutionDescriptor_t& convDesc,
cudnnConvolutionFwdAlgo_t& algo) {
int n = m_batchSize;
int c = conv.in_channels;
int h = conv.in_height;
int w = conv.in_width;

// Set input tensor. Following the manual, changed
// * CUDNN_DATA_FLOAT -> CUDNN_DATA_INT8x4, and
// * CUDNN_TENSOR_NCHW -> CUDNN_TENSOR_NCHW_VECT_C

// checkCUDNN(cudnnSetTensor4dDescriptor(
// srcTensorDesc, CUDNN_TENSOR_NCHW_VECT_C, CUDNN_DATA_INT8x4x4, n, c, h, w));

// checkCUDNN(cudnnSetTensor4dDescriptor(
// srcTensorDesc, CUDNN_TENSOR_NCHW_VECT_C, CUDNN_DATA_INT8x4x4, n, c, h, w));

checkCUDNN(cudnnSetTensor4dDescriptor(
srcTensorDesc, CUDNN_TENSOR_NCHW_VECT_C, CUDNN_DATA_INT8x4, n, c, h, w));

// Set convolution filter. Following the manual, changed
// * CUDNN_DATA_FLOAT -> CUDNN_DATA_INT8x4, and
// * CUDNN_TENSOR_NCHW -> CUDNN_TENSOR_NCHW_VECT_C
checkCUDNN(cudnnSetFilter4dDescriptor(filterDesc,
                                      CUDNN_DATA_INT8x4,
                                      CUDNN_TENSOR_NCHW_VECT_C,
                                      conv.out_channels,
                                      conv.in_channels,
                                      conv.kernel_size,
                                      conv.kernel_size));
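
// Note to reviewers (my reading of the cuDNN 6 documentation, not from the
// original sample): CUDNN_TENSOR_NCHW_VECT_C with CUDNN_DATA_INT8x4 packs
// four consecutive channels into each element, so the channel counts given
// to these descriptors are expected to be multiples of 4 and the INT8 data
// has to be laid out in that interleaved, vectorized format.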

// checkCUDNN(cudnnSetFilter4dDescriptor(filterDesc,
// CUDNN_DATA_INT8x4x4,
// CUDNN_TENSOR_NCHW_VECT_C,
// conv.out_channels,
// conv.in_channels,
// conv.kernel_size,
// conv.kernel_size));

// Set convolution operator. Following the manual, changed
// * CUDNN_DATA_FLOAT -> CUDNN_DATA_INT32
int pad_height = 0;
int pad_width = 0;
int stride_h = 1;
int stride_v = 1;
int dilation_h = 1;
int dilation_w = 1;
checkCUDNN(cudnnSetConvolution2dDescriptor(convDesc,
                                           pad_height,
                                           pad_width,
                                           stride_h,
                                           stride_v,
                                           dilation_h,
                                           dilation_w,
                                           CUDNN_CONVOLUTION,
                                           CUDNN_DATA_INT32));

// Compute output dimension. Following the manual, changed
// * CUDNN_DATA_FLOAT -> CUDNN_DATA_INT8x4, and
// * CUDNN_TENSOR_NCHW -> CUDNN_TENSOR_NCHW_VECT_C
checkCUDNN(cudnnGetConvolution2dForwardOutputDim(
    convDesc, srcTensorDesc, filterDesc, &n, &c, &h, &w));

// Set output tensor (Changed CUDNN_DATA_FLOAT to CUDNN_DATA_INT8x4x4, following the manual)
checkCUDNN(cudnnSetTensor4dDescriptor(
  dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));

//
// checkCUDNN(cudnnSetTensor4dDescriptor(
// dstTensorDesc, CUDNN_TENSOR_NCHW_VECT_C, CUDNN_DATA_INT8x4x4, n, c, h, w));

// Retrieve forward pass algorithm. We can either hardcode it to a specific
// algorithm or use cudnnGetConvolutionForwardAlgorithm. For the purpose
// of this test, either way works.
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
// Following also works
// checkCUDNN(cudnnGetConvolutionForwardAlgorithm(
//     cudnnHandle,
//     srcTensorDesc,
//     filterDesc,
//     convDesc,
//     dstTensorDesc,
//     CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
//     0,
//     &algo));


// Compute workspace size. We can either hardcode it to a specific number,
// or use cudnnGetConvolutionForwardWorkspaceSize. For the purpose of this
// test, either way works.
//size_t sizeInBytes = 1073741824;
// Following also works
   size_t sizeInBytes = 0;
   checkCUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle,
                                                     srcTensorDesc,
                                                     filterDesc,
                                                     convDesc,
                                                     dstTensorDesc,
                                                     algo,
                                                     &sizeInBytes));

return sizeInBytes;

}

/** Execute forward pass */
void ForwardPropagation(int* data,
                        float* conv1,
                        int* pconv1,
                        void* workspace) {
float alpha = 1.0f;
float beta = 0.0f;
checkCudaErrors(cudaSetDevice(m_gpuid));

#if 0

printf("cudnnHandle %d\n", cudnnHandle );
printf("dataTensor %d\n", dataTensor );
printf("data %d\n", data );
printf("conv1filterDesc %d\n", conv1filterDesc );
printf("pconv1 %d\n", pconv1 );
printf("conv1Desc %d\n", conv1Desc );
printf("conv1algo %d\n", conv1algo );
printf("workspace %d\n", workspace );
printf("m_workspaceSize %d\n", m_workspaceSize );
printf("conv1Tensor %d\n", conv1Tensor );
printf("conv1 %d\n", conv1 );
printf("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++**\n" );

#endif

checkCUDNN(cudnnConvolutionForward(cudnnHandle,
                                   &alpha,
                                   dataTensor,
                                   data,
                                   conv1filterDesc,
                                   pconv1,
                                   conv1Desc,
                                   conv1algo,
                                   workspace,
                                   m_workspaceSize,
                                   &beta,
                                   conv1Tensor,
                                   conv1));

}
};

//Function declaration
__global__ void Float2Int(float * in, int * out, int n);
__global__ void Device2DeviceCopy(int * in, int * out, int n);
__global__ void DeviceFloatCopy(float * in, float * out, int n);
__global__ void printVal(int* val, int elem);
__global__ void printArr(float * in, int elem);
// Support code function

bool matrix_same_dim(PyArrayObject* arr1, PyArrayObject* arr2)
{
if ((PyArray_NDIM(arr1) == 2) && (PyArray_NDIM(arr2) == 2))
return 1;
else
return 0;
}

#section support_code_apply

// Apply-specific support function
void APPLY_SPECIFIC(doConv)(
DTYPE_INPUT_0* x_ptr,
DTYPE_INPUT_1* y_ptr,
DTYPE_OUTPUT_0* z_ptr,
int m, int n, int p, int q, int out_row, int out_col)
{

float gpu_elapsed_time_ms=0;
int iterations = 1;
int gpu = 3; //For now
// input dimensions
size_t width = m; 
size_t height = n;
size_t channels =64;
int batch_size = 1;

// Create layer architecture
int out_channels =64;
int kernel_size = p;
int * in1;
int * in2 ;
float *out;






ConvolutionLayer conv1(
		(int)channels, out_channels, kernel_size, (int)width, (int)height);
TrainingContext context(gpu, batch_size, conv1);


//Allocate memory on device


(cudaMalloc((void **)&in1, (context.m_batchSize*channels*m*n)*sizeof(int)));
(cudaMalloc((void **)&in2, (channels*kernel_size*kernel_size)*sizeof(int)));
(cudaMalloc((void **)&out, ( sizeof(float) * context.m_batchSize * conv1.out_channels * conv1.out_height * conv1.out_width)));

//Create CUDA stream
cudaStream_t  stream;
cudaStreamCreate( &stream);

//CUDA events
cudaEvent_t start1, stop1;
	cudaEventCreate(&start1);
	cudaEventCreate(&stop1);
	cudaEventRecord(start1, 0);

Float2Int<<<((m*n*channels)/1024)+1,1024>>>(x_ptr,in1,sizeof(int) * m*n*context.m_batchSize*channels); //Equivalent to h_a
Float2Int<<<((p*q*channels)/1024)+1,1024>>>(y_ptr,in2,sizeof(int) * kernel_size*kernel_size*channels ); //Equivalent to h_b

//printVal<<<1,1>>>(in1, 10);
    // start to count execution time of typecasting 
    cudaEventRecord(stop1, 0);
    cudaEventSynchronize(stop1);

    // compute elapsed time of typecasting
    cudaEventElapsedTime(&gpu_elapsed_time_ms, start1, stop1);
printf("Time elapsed for typecasting in Convolution kernels  for Input %dx%d, Kernel size  %dx%d on GPU is : %f ms.\n\n", m, n, p, q, gpu_elapsed_time_ms);


 int* d_pconv1;
 checkCudaErrors(cudaMalloc(&d_pconv1, sizeof(int)*conv1.pconv.size()));
//Fill d_pconv1
 Device2DeviceCopy<<<((p*q*channels)/1024)+1,1024>>>(in2, d_pconv1,sizeof(int)*conv1.pconv.size());
/*
checkCudaErrors(cudaMemcpyAsync(d_pconv1,
                              in2,
                              sizeof(int)*conv1.pconv.size(),
                              cudaMemcpyDeviceToDevice));
*/





	cudaEvent_t start, stop;
	cudaEventCreate(&start);
	cudaEventCreate(&stop);

	// start to count execution time of GPU version
	cudaEventRecord(start, 0);

void* d_cudnn_workspace = nullptr;
if (context.m_workspaceSize > 0) {
	(cudaMalloc(&d_cudnn_workspace, context.m_workspaceSize));
}

// Start forward pass
//printf("Begin forwrad pass %d \n" , conv1.pconv.size());
checkCudaErrors(cudaDeviceSynchronize());
//auto t1 = std::chrono::high_resolution_clock::now();
// Temporary buffers and workspaces
printf("Begin  %d\n", conv1.pconv.size());

for (int iter = 0; iter < iterations; ++iter) {
	context.ForwardPropagation(in1, out, d_pconv1, d_cudnn_workspace);
}
//printf("Done forwrad pass\n");
checkCudaErrors(cudaDeviceSynchronize());
//auto t2 = std::chrono::high_resolution_clock::now();

/*
printf(
    "Iteration time: %f ms\n",
    std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count() /
        1000.0f / iterations);
*/
// Free data structures

//printf("%f  %f\n", out[0],out[1]);

//printArr<<<1,1>>>(out, 20);
if (d_cudnn_workspace != nullptr)
	(cudaFree(d_cudnn_workspace));


//TO-DO Call dp4a convolution 
    cudaEventRecord(stop, 0);
	cudaEventSynchronize(stop);

	// compute elapsed time of the GPU computation
	cudaEventElapsedTime(&gpu_elapsed_time_ms, start, stop);
	//printf("Time elapsed on matrix convolutions  for Input %dx%d,  Kernel size %dx%d on GPU is : %f ms. device float copy is channels %d  height %d  width %d \n\n", m, n, p, q, gpu_elapsed_time_ms , conv1.out_channels, conv1.out_height, conv1.out_width);

DeviceFloatCopy<<<((conv1.out_channels * conv1.out_height * conv1.out_width)/1024)+1,1024>>>(out,z_ptr,sizeof(float) * context.m_batchSize * conv1.out_channels * conv1.out_height * conv1.out_width);

cudaStreamSynchronize (stream);
cudaStreamDestroy(stream);
cudaFree(in1);
cudaFree(in2);
cudaFree(out);
//checkCudaErrors(cudaFree(d_data));
    checkCudaErrors(cudaFree(d_pconv1));

}

void check_real_matrix(const char* name, CudaNdarray* m)
{
char msg[MAX_MSG_LEN];

// check whether it’s a matrix (2d CudaNdarray)
if (CudaNdarray_NDIM(m) != 2)
{
snprintf(msg, MAX_MSG_LEN, "%s must be two-dimensional", name);
throw std::runtime_error( msg);
}

// check if unstrided in second dim
if (CudaNdarray_HOST_STRIDES(m)[1] != 1 )
{
snprintf(msg, MAX_MSG_LEN, "%s must be unstrided (i.e. have contiguous elements) in its second (i.e. last) dimension", name);
throw std::runtime_error( msg);
}

// all checks correct
}
// Apply-specific main function
int APPLY_SPECIFIC(doConvolution8Bit)(CudaNdarray* input0,
CudaNdarray* input1,
CudaNdarray** output0)
{
int m,n,p,q;
// Validate that the inputs have the same shape

 //printf("Enter main......");
 check_real_matrix("input0", input0);
 check_real_matrix("input1", input1);

 //Input dimensions M*N     
 m = CudaNdarray_HOST_DIMS(input0)[0];
 n=  CudaNdarray_HOST_DIMS(input0)[1];
 
 //Filter (kernel) dimensions  size P*Q     

 p  = CudaNdarray_HOST_DIMS(input1)[0];
 q =  CudaNdarray_HOST_DIMS(input1)[1];
                                              
											  
 printf("M:%d\n", m);
 printf("n:%d\n", n);
 printf("P:%d\n", p);
 printf("Q:%d\n", q);

int out_dim[2] = {m-2,n-2}; //first try with -2 
int g = CudaNdarray_prep_output(output0,2,out_dim, 0);
/*printf("%d",CudaNdarray_HOST_DIMS(input0));*/
APPLY_SPECIFIC(doConv)(
                        (DTYPE_INPUT_0*)CudaNdarray_DEV_DATA(input0),
                                                    (DTYPE_INPUT_1*)CudaNdarray_DEV_DATA(input1),
                                                    (DTYPE_OUTPUT_0*)CudaNdarray_DEV_DATA(*output0),
                                                     m,n,p,q,m-2,n-2
                                                     );
return 0;

}

__global__ void Float2Int(float * in, int * out, int n)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
//float upper_temp = *upper;
if(i < n)
{
//Actual typecasting
out[i] = (int) in[i];
//printf("Float to Int %d\n",out[i]);

}

}

__global__ void Device2DeviceCopy(int * in, int * out, int n)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
//float upper_temp = *upper;

    if(i < n)
    {
            //Actual typecasting
             out[i] =  in[i];
            //printf("Float to Int %d\n",out[i]);

    }

}

__global__ void DeviceFloatCopy(float * in, float * out, int n)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
//float upper_temp = *upper;

    if(i < n)
    {
            //Actual typecasting
             out[i] =  in[i];
            //printf("Float to Int %d\n",out[i]);

    }

}

__global__ void printArr(float * in, int elem)
{
int i = 0;
for( i = 0; i< elem; i++ ){
printf("in1 %f \n", in[i]);
}
}

__global__ void printVal(int* val, int elem)
{
int i = 0;
for( i = 0; i< elem; i++ ){
printf("Array elements are %d \n", val[i]);
}
}

######################################################################################