Disabling TF32 in cuDNN at runtime on Ampere

I’m investigating an issue with a model that was trained on a system with an RTX 6000 GPU and that I’m now running for inference on another system with an RTX A6000. At runtime, the model is converted to a DAG built around cuDNN. It appears that using the TF32 floating-point format for inference on the RTX A6000 (an Ampere GPU; the Turing-based RTX 6000 has no TF32 support), with a model trained in FP32, accumulates enough error to make the output of the inference engine unsatisfactory. I’ve ruled out the cuDNN version difference (cuDNN 7.6 for training vs. cuDNN 8.1 for inference) as the cause by using an RTX 6000 in both environments. Is there a way to disable TF32 computation at runtime, on demand, for a specific cuDNN context? It looks like PyTorch has a way to do this, but I can’t find it documented in the cuDNN API.
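To illustrate how rounding inputs to TF32 can shift the result of an FP32-trained model, here is a minimal host-side sketch in C. It emulates TF32 by rounding each FP32 input to 10 explicit mantissa bits before multiplying, while keeping the accumulation in FP32, which roughly mirrors how TF32 Tensor Core math treats convolution/GEMM inputs. The function names are hypothetical (they are not cuDNN API), and the rounding mode (round-to-nearest, ties away from zero) is an assumption; the hardware path may round differently.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: round an FP32 value to TF32 precision (8-bit exponent,
   10 explicit mantissa bits). Adding 0x1000 before clearing the low 13 bits
   gives round-to-nearest, ties away from zero. Assumption: real hardware
   rounding may differ. Inf/NaN inputs are not handled specially. */
static float round_to_tf32(float x)
{
    uint32_t bits;
    memcpy(&bits, &x, sizeof bits);
    bits += 0x1000u;      /* half of the 2^13 range being discarded */
    bits &= 0xFFFFE000u;  /* keep sign, exponent and top 10 mantissa bits */
    memcpy(&x, &bits, sizeof x);
    return x;
}

/* FP32 inputs, FP32 accumulation (the non-TF32 path). */
static float dot_fp32(const float *a, const float *b, size_t n)
{
    assert(a != NULL && b != NULL);
    float acc = 0.0f;
    for (size_t i = 0; i < n; ++i)
        acc += a[i] * b[i];
    return acc;
}

/* Inputs rounded to TF32, FP32 accumulation (emulated TF32 Tensor Core math). */
static float dot_tf32(const float *a, const float *b, size_t n)
{
    assert(a != NULL && b != NULL);
    float acc = 0.0f;
    for (size_t i = 0; i < n; ++i)
        acc += round_to_tf32(a[i]) * round_to_tf32(b[i]);
    return acc;
}

/* Double-precision reference used to measure the error of both variants. */
static double dot_ref(const float *a, const float *b, size_t n)
{
    assert(a != NULL && b != NULL);
    double acc = 0.0;
    for (size_t i = 0; i < n; ++i)
        acc += (double)a[i] * (double)b[i];
    return acc;
}
```

For example, with n = 1024 and every element of a and b set to 1.0001f, round_to_tf32 maps 1.0001f to exactly 1.0f, so dot_tf32 returns 1024.0f, while the double-precision reference is about 1024.2 (a relative error of roughly 2e-4 in a single dot product). dot_fp32 stays much closer to the reference. Errors of this kind compound across layers, which is consistent with the degradation described above.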

Edit: It looks like I can call cudnnSetConvolutionMathType (…, CUDNN_FMA_MATH), which is supposed to prevent TF32 from being used, but the results are binary identical whether or not I make that call.

Hi @jbaumgart,
Yes, CUDNN_FMA_MATH lets you disable TF32 computation.
Regarding the results, however, can you please share more details with us?


Execution time is the same whether or not I set CUDNN_FMA_MATH.

Initialization of convolution node:

cuCheck (cudnnCreateConvolutionDescriptor (&m_convolution_descriptor));
printf ("Using CUDNN_FMA_MATH\n");
// Restrict this descriptor to FMA kernels so TF32 Tensor Core math is not used
cuCheck (cudnnSetConvolutionMathType (m_convolution_descriptor, CUDNN_FMA_MATH));

Later execution:

cuCheck (cudnnConvolutionForward (m_cudnn_ctxt,
                                 &alpha, m_incomings[0]->getOutputTensorDescriptor(),
                                 m_incomings[0]->getOutput(), // input is the output of previous layer
                                 m_filter_descriptor, m_d_filters,
                                 m_convolution_descriptor, m_convolution_algorithm_descriptor,
                                 m_cudnn_dWksp, m_workspace_bytes,
                                 &beta, m_output_descriptor, m_d_output));

Same result and timing whether or not I include the cudnnSetConvolutionMathType() call.
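When comparing a run with CUDNN_FMA_MATH against one without it, it can help to quantify the difference rather than rely on a single identical/not-identical check. The sketch below is a hypothetical host-side helper (compare_outputs and diff_report are illustrative names, not cuDNN API). It assumes both FP32 output tensors have already been copied back to host memory, for example with cudaMemcpy from m_d_output, and it counts bitwise mismatches and reports the largest absolute difference.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Summary of how two FP32 output buffers differ. */
typedef struct {
    size_t mismatches;   /* elements whose bit patterns differ */
    double max_abs_diff; /* largest |a[i] - b[i]| among mismatches */
} diff_report;

/* Hypothetical helper: compare two host-side output tensors of n floats.
   Bitwise comparison treats +0.0f and -0.0f as different, which matches a
   strict "binary identical" test; their numeric difference is still 0. */
static diff_report compare_outputs(const float *a, const float *b, size_t n)
{
    assert(a != NULL && b != NULL);
    diff_report r = {0, 0.0};
    for (size_t i = 0; i < n; ++i) {
        uint32_t ba, bb;
        memcpy(&ba, &a[i], sizeof ba);
        memcpy(&bb, &b[i], sizeof bb);
        if (ba != bb) {
            ++r.mismatches;
            double d = (double)a[i] - (double)b[i];
            if (d < 0) d = -d;
            if (d > r.max_abs_diff) r.max_abs_diff = d;
        }
    }
    return r;
}
```

A mismatch count of zero confirms the two runs really are bit-for-bit identical. A nonzero count with a very small max_abs_diff points to rounding-level differences, for example from TF32 being used in one run and not the other.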

Hi @jbaumgart,
Apologies for the delay. Are you still facing the issue?

No longer facing this issue. Thanks for checking.