I implemented a network with cuda, but I want to slice the output result in fp16, but the result is nan, the result of fp32 is normal, what is wrong with my fp16 implementation?
// my implementation
ouput.size(batch, 193, 768)
my_output= ouput[:, 0]
output.size(batch, 768)
// int main.cc , my func is invokesplitout
int n = embed_dim_;
int s = seq_len;
int m = input_batch_size * s;
invokesplitout(ouput, my_output, m, n, s, stream_);
// invokesplitout
// I have tried to use const half* __restrict in, const half* __restrict out when it's fp16, but also nan. Anyway the fp32 is normal.
template<typename T>
__global__ void splitout(const T* in,
T* out,
const int m,
const int n,
const int s)
{
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x)
{
int col_idx = id % n;
int row_idx = id / n / s;
int tar_row_idx = row_idx * n;
int tar_idx = tar_row_idx + col_idx;
out[id] = in[tar_idx];
}
}
template<>
__global__ void splitout(const half* in,
half* out,
const int m,
const int n,
const int s)
{
half2* out_ptr = (half2*)out;
const half2* in_ptr = (half2*)in;
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x)
{
int col_idx = id % n;
int row_idx = id / n / s;
int tar_row_idx = row_idx * n;
int tar_idx = tar_row_idx + col_idx;
half2 d1 = in_ptr[tar_idx];
out_ptr [id] = d1;
// printf("value=%lu\n", d1);
}
}
template<typename T>
void invokesplitout(
const T* in, T* out, const int m, const int n, const int s, cudaStream_t stream)
{
const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16
dim3 block, grid;
if (n / 4 / data_type_factor <= 1024) {
block.x = n / 4 / data_type_factor;
grid.x = m;
}
else {
block.x = 1024;
grid.x = (m * n + 1023) / 1024;
}
// splitout<<<grid, block, 0, stream>>>(in, out, m, n, s);
splitout<<<grid, block, 0, stream>>>(in, out, m, n / data_type_factor, s);
}
template void invokesplitout(const float* in, float* out, const int m, const int n, const int s, cudaStream_t stream);
template void invokesplitout(const half* in, half* out, const int m, const int n, const int s, cudaStream_t stream);