my device is 3090 , environment pytorch version : 2.3.1 + cu118 cudnn version : 8700
I’m trying to compare the accuracy of PyTorch’s convolution with cuDNN’s convolution. Forsmall channel sizes (e.g., 4 or 8), the results are identical. However, when the number ofchannels increases to 64 or more, l observe a maximum difference of 0.04 on the RTX 3090while there is no difference on the P40.
I suspect this discrepancy is due to the different Tensor Core architectures on these GPUs. lsit possible to disable the Tensor Core module in cuDNN to address this accuracy issueduring convolution operations?
My routine is as follows:
import torch
import torch.nn.functional as F
import ctypes
import numpy as np
from ctypes import c_void_p, c_int, c_float, c_size_t
import random
torch.random.manual_seed(42)
np.random.seed(42)
random.seed(42)
# load cuDNN
cudnn = ctypes.cdll.LoadLibrary('cudnn64_8.dll')
CUDNN_TENSOR_NCHW = 0
CUDNN_DATA_FLOAT = 0
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM = 0
CUDNN_CONVOLUTION = 0
CUDNN_CROSS_CORRELATION = 1
cudnn_handle = c_void_p()
e = cudnn.cudnnCreate(ctypes.byref(cudnn_handle))
assert e == 0, f"Error in cudnnCreate: {e}"
input_tensor = torch.randn(1, 64, 512, 512, device='cuda', dtype=torch.float32)
kernel = torch.randn(64, 64, 3, 3, device='cuda', dtype=torch.float32)
output_torch = F.conv2d(input_tensor, kernel)
output_torch = output_torch.cpu().numpy()
# print(output_torch)
output_array = np.zeros(output_torch.shape, dtype=np.float32)
input_desc = c_void_p()
e = cudnn.cudnnCreateTensorDescriptor(ctypes.byref(input_desc))
assert e == 0, f"Error in cudnnCreateTensorDescriptor: {e}"
e = cudnn.cudnnSetTensor4dDescriptor(input_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, *input_tensor.shape)
assert e == 0, f"Error in cudnnSetTensor4dDescriptor: {e}"
filter_desc = c_void_p()
e = cudnn.cudnnCreateFilterDescriptor(ctypes.byref(filter_desc))
assert e == 0, f"Error in cudnnCreateFilterDescriptor: {e}"
e = cudnn.cudnnSetFilter4dDescriptor(filter_desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, *kernel.shape)
assert e == 0, f"Error in cudnnSetFilter4dDescriptor: {e}"
conv_desc = c_void_p()
e = cudnn.cudnnCreateConvolutionDescriptor(ctypes.byref(conv_desc))
assert e == 0, f"Error in cudnnCreateConvolutionDescriptor: {e}"
e = cudnn.cudnnSetConvolution2dDescriptor(conv_desc, 0, 0, 1, 1, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT)
assert e == 0, f"Error in cudnnSetConvolution2dDescriptor: {e}"
e = cudnn.cudnnSetConvolutionMathType(conv_desc, 3)
output_desc = c_void_p()
e = cudnn.cudnnCreateTensorDescriptor(ctypes.byref(output_desc))
assert e == 0, f"Error in cudnnCreateTensorDescriptor: {e}"
e = cudnn.cudnnSetTensor4dDescriptor(output_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, *output_array.shape)
assert e == 0, f"Error in cudnnSetTensor4dDescriptor: {e}"
d_input = input_tensor.clone()
d_kernel = kernel.clone()
d_output = torch.zeros(output_array.shape, device='cuda', dtype=torch.float32)
workspace_size = c_size_t()
cudnn.cudnnGetConvolutionForwardWorkspaceSize(
cudnn_handle,
input_desc,
filter_desc,
conv_desc,
output_desc,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
ctypes.byref(workspace_size)
)
workspace = torch.cuda.ByteTensor(workspace_size.value)
alpha = c_float(1.0)
beta = c_float(0.0)
e = cudnn.cudnnConvolutionForward(
cudnn_handle,
ctypes.byref(alpha),
input_desc,
c_void_p(d_input.data_ptr()),
filter_desc,
c_void_p(d_kernel.data_ptr()),
conv_desc,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
c_void_p(workspace.data_ptr()),
workspace_size,
ctypes.byref(beta),
output_desc,
c_void_p(d_output.data_ptr())
)
assert e == 0, f"Error in cudnnConvolutionForward: {e}"
output_array = d_output.cpu().numpy()
# print("cuDNN 3x3 Convolution Output:")
print(np.abs(output_torch - output_array).max())
# print(output_array)
cudnn.cudnnDestroyTensorDescriptor(input_desc)
cudnn.cudnnDestroyFilterDescriptor(filter_desc)
cudnn.cudnnDestroyConvolutionDescriptor(conv_desc)
cudnn.cudnnDestroyTensorDescriptor(output_desc)
cudnn.cudnnDestroy(cudnn_handle)
“In the code below, I called the PyTorch convolution method and directly imported the cuDNN DLL to call the cuDNN convolution function separately. I controlled the input Tensor and convolution method to be the same, and then calculated and output the maximum difference between the final output results of the two methods.”
This translation accurately captures the technical details and intent of the original Chinese sentence, describing the comparison between PyTorch’s convolution implementation and a direct call to cuDNN’s convolution function, with an emphasis on controlling the inputs and comparing the outputs.