Overloading for Float and Half: more than one operator matches these operands

Hi, I’m writing some TensorRT plugins and using templated functions so that they work with either FP32 or FP16. When I do this, the compiler reports “error: more than one operator matches these operands:” or “error: more than one instance of overloaded function matches the argument list:”. Is there an easy way to resolve this without dropping half support? Or do I have to write explicit float and __half versions, or manually cast everything to a common “master” type? The kernel is a modified version of PyTorch’s implementation.

I’m currently using CUDA 10.2.

Cheers

Below is example code showing why I want to use templates; the full source can be found here: https://github.com/5had3z/stereo-to-all. If anyone else has a TensorRT implementation of any of the PWC-Net operations (correlation or grid sample), please share it. Otherwise, if this turns out to be new(?), feel free to use mine once I’m done.

// Launches grid_sampler_kernel on `stream` for the plugin's configured data type.
//
// inputs[0] is passed to the kernel together with m_input_dims.d[0..2];
// inputs[1] is passed as the second tensor operand; outputs[0] receives the result,
// whose spatial extent is passed as m_output_dims.d[1..2]. All buffers are
// reinterpreted as float or __half according to m_datatype.
//
// `workspace` is not used.
//
// Returns 0 (cudaSuccess) on success. Otherwise returns a non-zero cudaError_t
// value: either the launch error reported by cudaGetLastError(), or
// cudaErrorInvalidValue if m_datatype is neither kFLOAT nor kHALF.
int GridSamplerPlugin::enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream)
{
    // Widen before multiplying. batchSize and Dims::d[] are int, so without the
    // cast the product is computed (and can overflow) in 32-bit arithmetic
    // before it is converted to int64_t.
    // NOTE(review): PyTorch's grid sampler launches N * out_H * out_W threads,
    // but this uses the *input* spatial dims. Confirm that the modified kernel
    // expects this; otherwise use m_output_dims.d[1] * m_output_dims.d[2].
    const int64_t count = static_cast<int64_t>(batchSize) * m_input_dims.d[1] * m_input_dims.d[2];

    // An empty problem needs no work. A zero-block launch would otherwise fail
    // with an invalid-configuration error (PyTorch guards this the same way).
    if (count == 0)
    {
        return cudaSuccess;
    }

    if (m_datatype == nvinfer1::DataType::kFLOAT)
    {
        grid_sampler_kernel<float><<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(count,
            reinterpret_cast<const float*>(inputs[0]), m_input_dims.d[0], m_input_dims.d[1], m_input_dims.d[2],
            reinterpret_cast<const float*>(inputs[1]),
            reinterpret_cast<float*>(outputs[0]), m_output_dims.d[1], m_output_dims.d[2],
            m_interpolation_mode, m_padding_mode, m_align_corners);
    }
    else if (m_datatype == nvinfer1::DataType::kHALF)
    {
        grid_sampler_kernel<__half><<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(count,
            reinterpret_cast<const __half*>(inputs[0]), m_input_dims.d[0], m_input_dims.d[1], m_input_dims.d[2],
            reinterpret_cast<const __half*>(inputs[1]),
            reinterpret_cast<__half*>(outputs[0]), m_output_dims.d[1], m_output_dims.d[2],
            m_interpolation_mode, m_padding_mode, m_align_corners);
    }
    else
    {
        // Previously an unsupported type silently skipped the launch and
        // reported success, leaving outputs[0] unwritten.
        return cudaErrorInvalidValue;
    }

    // Kernel launches do not return errors directly. This reports (and clears)
    // any launch-configuration error; it does not wait for the kernel to finish.
    return cudaGetLastError();
}