In my PyTorch CUDA extension, I require an atomicAdd. But I get the following error:
error: no instance of overloaded function "atomicAdd" matches the argument list argument types are: (c10::Half *, c10::Half)
The kernel is dispatched using PyTorch’s AT_DISPATCH_FLOATING_TYPES_AND_HALF macro; it only compiles when I change it to AT_DISPATCH_FLOATING_TYPES, indicating that it works with float and double, but not the half datatype.
My code looks as follows:
template <typename scalar_t>
__global__ void f(torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> x, ...) {
    // ...
    atomicAdd(&x[i][j], y);
    // ...
}
I tried:
System info:
Driver Version: 510.47.03
CUDA Version: 11.6
Model: Nvidia GeForce GTX 1070
It would be nice if someone could help.
The following overloads of atomicAdd exist (see Programming Guide :: CUDA Toolkit Documentation):
int atomicAdd(int* address, int val);
unsigned int atomicAdd(unsigned int* address,
unsigned int val);
unsigned long long int atomicAdd(unsigned long long int* address,
unsigned long long int val);
float atomicAdd(float* address, float val);
double atomicAdd(double* address, double val);
__half2 atomicAdd(__half2 *address, __half2 val);
__half atomicAdd(__half *address, __half val);
__nv_bfloat162 atomicAdd(__nv_bfloat162 *address, __nv_bfloat162 val);
__nv_bfloat16 atomicAdd(__nv_bfloat16 *address, __nv_bfloat16 val);
If c10::Half is compatible with one of those types, you can just use typecasts. Otherwise, you need to implement your own compare-and-swap-based atomicAdd.
Thank you for the answer! That link was helpful. It says:
The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
I had already tried overloading the function and then casting to __half in the case where scalar_t == c10::Half. But it turns out that my GeForce GTX 1070 does not support __half atomicAdd(__half *address, __half val) as it has compute capability 6.1 as shown here.
Do you know of a way to work around the issue? How do I implement my own atomicAdd for __half values?
It is explained in the linked documentation, right above the atomicAdd overloads.
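For readers without the documentation at hand, the pattern it describes is: read the current value, compute the new value, try to swap it in with a compare-and-swap, and retry if another thread changed the word in between. Below is a minimal host-side C++ sketch of that loop, assuming std::atomic<std::uint32_t> as a stand-in for a device word and compare_exchange_weak as a stand-in for atomicCAS. The helper names (float_to_bits, bits_to_float, cas_add_float) are hypothetical and only serve the illustration; this is not device code.

```cpp
#include <atomic>
#include <cstdint>
#include <cstring>

// Reinterpret a float as its 32-bit pattern (memcpy avoids aliasing problems).
inline std::uint32_t float_to_bits(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    return u;
}

// Reinterpret a 32-bit pattern as a float.
inline float bits_to_float(std::uint32_t u) {
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Hypothetical helper: atomically add `val` to the float stored in `word`
// and return the previous value, using only a 32-bit compare-and-swap.
inline float cas_add_float(std::atomic<std::uint32_t>& word, float val) {
    std::uint32_t assumed = word.load();
    std::uint32_t desired;
    do {
        // Compute the new bit pattern from the value we assume is stored.
        desired = float_to_bits(bits_to_float(assumed) + val);
        // If the word changed in the meantime, compare_exchange_weak fails,
        // reloads `assumed` with the current contents, and we try again.
        // The comparison is on the integer bits, so NaN cannot stall the loop.
    } while (!word.compare_exchange_weak(assumed, desired));
    return bits_to_float(assumed);
}
```

The same structure carries over to a GPU kernel, with atomicCAS on an unsigned int pointer taking the place of compare_exchange_weak.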
Thank you! It took me a while, but I eventually figured out the following code (tested and working for my purpose):
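#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ < 700
// The native __half atomicAdd needs compute capability 7.x, so emulate it with a 32-bit atomicCAS.
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
    // atomicCAS operates on 32-bit words, so round the address down to a 4-byte boundary.
    unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    do {
        assumed = old;
        // Load the target 16 bits into a c10::Half (x holds the raw bits), so that the
        // addition is a floating-point add and not an integer add on the bit pattern.
        c10::Half hsum;
        hsum.x = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
        hsum = hsum + val;
        // Write the sum back into its half of the word, keeping the neighbouring half.
        old = reinterpret_cast<size_t>(address) & 2
            ? (old & 0xffff) | (static_cast<unsigned int>(hsum.x) << 16)
            : (old & 0xffff0000) | hsum.x;
        old = atomicCAS(address_as_ui, assumed, old);
        // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
    } while (assumed != old);
}
#endif
#endif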
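The word-packing part of this approach can be checked on the host, without a GPU. The sketch below is an assumption-laden analogue rather than device code: std::atomic<std::uint32_t> stands in for the aligned 32-bit word, a bool selects the high or low 16 bits instead of testing the address, and the hypothetical helper cas_update_lane applies a caller-supplied operation to the raw 16 bits. It is only meant to show that the neighbouring 16 bits survive the update.

```cpp
#include <atomic>
#include <cstdint>

// Hypothetical helper: atomically replace one 16-bit half of a 32-bit word
// with op(old_half), leaving the other half untouched. Returns the old half.
template <typename Op>
inline std::uint16_t cas_update_lane(std::atomic<std::uint32_t>& word,
                                     bool high_lane, Op op) {
    std::uint32_t assumed = word.load();
    std::uint32_t desired;
    std::uint16_t old_lane;
    do {
        // Extract the 16 bits we want to modify.
        old_lane = static_cast<std::uint16_t>(
            high_lane ? (assumed >> 16) : (assumed & 0xffffu));
        const std::uint32_t new_lane = static_cast<std::uint16_t>(op(old_lane));
        // Merge the new 16 bits with the unchanged neighbouring 16 bits.
        desired = high_lane ? ((assumed & 0x0000ffffu) | (new_lane << 16))
                            : ((assumed & 0xffff0000u) | new_lane);
        // On failure `assumed` is refreshed with the current word and we retry.
    } while (!word.compare_exchange_weak(assumed, desired));
    return old_lane;
}
```

In the half-precision case, the operation passed in would decode the 16 bits as a half, add the operand in floating point, and encode the result again.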