TensorRT Sigmoid activation function produces slightly different results from the original TensorFlow sigmoid op.

I’m converting a TensorFlow graph to a TensorRT engine. For the same input, the TensorFlow graph and the TensorRT engine produce identical results up to a tf.nn.sigmoid op, but the output of the sigmoid op itself differs slightly between the TF graph and the TensorRT engine. I wonder if this is to be expected.

For example, given an input:
[8.879764 -8.724520 -10.623482 -11.822342 -12.868923 -11.805139 -13.092369 -11.573037 -11.112819 -11.025951]

tf.nn.sigmoid() produces:
[0.99986076 0.00016252598 0.000024337136 0.0000073386827 0.0000025768695 0.0000074660811 0.000002060895 0.0000094165052 0.0000014919674 0.000016273782]

while the TensorRT engine produces:
[0.999949 0.000442 0.000066 0.000020 0.000007 0.000020 0.000006 0.000026 0.000041 0.000044]

I think this is normal (unless you have reduced precision disabled), as one of the optimizations TensorRT applies is FP16 and INT8 reduced-precision calibration; hence you see truncated, less precise results with TensorRT than with TF, which uses full precision (FP32).

You can go through https://devblogs.nvidia.com/tensorrt-3-faster-tensorflow-inference/ to learn more about TRT.
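
As a rough illustration of how much a reduced-precision sigmoid can drift, here is a plain-NumPy sketch; this only simulates FP16 rounding and is not what TensorRT actually executes:

import numpy as np

x = np.array([8.879764, -8.724520, -10.623482, -11.822342, -12.868923,
              -11.805139, -13.092369, -11.573037, -11.112819, -11.025951])

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

fp32 = sigmoid(x.astype(np.float32))  # full precision
fp16 = sigmoid(x.astype(np.float16))  # simulated reduced precision; note that
                                      # np.exp can overflow to inf in float16,
                                      # driving some outputs all the way to 0.0
print(np.abs(fp32 - fp16.astype(np.float32)))  # per-element drift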

But in this case, I’m not using reduced precision. I’m creating the TensorRT engine using the buildCudaEngine() function of the C++ API. If I’m not mistaken, that creates the TensorRT engine in 32-bit (FP32) precision.

I see; in that case, it looks abnormal. Maybe you can query the builder to know for sure whether the engine is created in FP32. Otherwise, I think an NVIDIA person might have more information for you.
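
In the C++ API that check would be IBuilder::getHalf2Mode() and IBuilder::getInt8Mode() before calling buildCudaEngine(). For reference, here is a minimal sketch of the same check through the TensorRT 3 Python bindings; I’m assuming the snake_case names mirror the C++ methods, so please verify them against your version:

import tensorrt as trt

G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO)
builder = trt.infer.create_infer_builder(G_LOGGER)

# If both reduced-precision modes report False, buildCudaEngine() should
# produce an FP32 engine. Method names are assumed to mirror the C++
# getHalf2Mode()/getInt8Mode() and may differ by TensorRT version.
print('half2 (FP16) mode:', builder.get_half2_mode())
print('int8 mode:', builder.get_int8_mode())

builder.destroy()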

Hi,

We have tested the tf.nn.sigmoid() op with TensorRT but cannot reproduce this issue.
Please remember to convert the input to np.float32 to preserve full precision.

data = data.astype(np.float32)

Here is our testing code for your reference:

from tensorrt.parsers import uffparser
import tensorflow as tf
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # initializes the CUDA context; required before any cuda.mem_alloc
import numpy as np
import uff


MAX_WORKSPACE = 1 << 20  # 1 MiB of builder workspace is plenty for this one-op graph
MAX_BATCHSIZE = 1
G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO)

# a minimal graph: a [1, 10] placeholder feeding a single sigmoid
inputs = tf.placeholder(dtype=tf.float32, shape=[1,10])
output = tf.nn.sigmoid(inputs, name='out')

data = np.expand_dims(np.array([8.879764, -8.724520, -10.623482, -11.822342, -12.868923, -11.805139, -13.092369, -11.573037, -11.112819, -11.025951]), axis=0)
data = data.astype(np.float32)

# run the op in TensorFlow, then freeze the graph for conversion
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf_result = sess.run(output,feed_dict={inputs:data})

    graphdef = tf.get_default_graph().as_graph_def()
    frozen_graph = tf.graph_util.convert_variables_to_constants(sess, graphdef, ['out'])
    tf_model = tf.graph_util.remove_training_nodes(frozen_graph)

uff_model = uff.from_tensorflow(tf_model, ['out'])  # serialize the frozen graph to UFF

# the placeholder was created without an explicit name, so it is registered as "Placeholder"
parser = uffparser.create_uff_parser()
parser.register_input("Placeholder", (1,10,1), 0)
parser.register_output("out")

# uff_to_trt_engine builds with its default datatype, i.e. full FP32 precision
engine = trt.utils.uff_to_trt_engine(G_LOGGER, uff_model, parser, MAX_BATCHSIZE, MAX_WORKSPACE)
parser.destroy()

runtime = trt.infer.create_infer_runtime(G_LOGGER)
context = engine.create_execution_context()

trt_result = cuda.pagelocked_empty(10, dtype=np.float32)  # page-locked host buffer for the output

d_input = cuda.mem_alloc(10 * data.dtype.itemsize)         # device-side input buffer
d_output = cuda.mem_alloc(10 * trt_result.dtype.itemsize)  # device-side output buffer

bindings = [int(d_input), int(d_output)]
stream = cuda.Stream()

cuda.memcpy_htod_async(d_input, data, stream)
context.enqueue(1, bindings, stream.handle, None)
cuda.memcpy_dtoh_async(trt_result, d_output, stream)
stream.synchronize()  # wait for inference and the async copy to finish before reading trt_result

print('TensorFlow:')
print(tf_result)
print('\nTensorRT:')
print(trt_result)
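
When you compare the two printouts, note that bitwise equality is too strict even with both sides in FP32, since the two runtimes may order floating-point operations differently; a tolerance check is more meaningful. A minimal sketch (the rtol/atol values here are arbitrary, for illustration only):

# tolerances are illustrative, not an official acceptance criterion
np.testing.assert_allclose(tf_result.ravel(), trt_result, rtol=1e-5, atol=1e-7)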

Thanks.

Thank you for looking into this and providing the sample code. I can confirm that I get identical results in the isolated test case you provided. However, when the sigmoid layer is part of a larger network, I still see the difference between TensorFlow and TensorRT, even though the output of the preceding conv layer is identical between the two. Let me see if I can extract a test case that exhibits the problem.
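
For concreteness, something like the following is the shape of the reproducer I have in mind: the smallest graph that still has a conv feeding the sigmoid, with fixed weights so both runtimes see the identical network (all shapes and values below are made up for illustration):

import numpy as np
import tensorflow as tf

# hypothetical minimal reproducer: conv -> sigmoid with deterministic weights
inputs = tf.placeholder(dtype=tf.float32, shape=[1, 8, 8, 1], name='input')
kernel = tf.constant(np.random.RandomState(0).randn(3, 3, 1, 1).astype(np.float32))
conv = tf.nn.conv2d(inputs, kernel, strides=[1, 1, 1, 1], padding='SAME', name='conv')
output = tf.nn.sigmoid(conv, name='out')

with tf.Session() as sess:
    data = np.random.RandomState(1).randn(1, 8, 8, 1).astype(np.float32)
    conv_out, tf_result = sess.run([conv, output], feed_dict={inputs: data})
    # conv_out can be checked against the engine's intermediate output and
    # tf_result against its final output, after freezing and converting the
    # graph exactly as in the sample above.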

Hi,

Have you tested our latest TensorRT package?
If not, could you check whether this issue also occurs on TensorRT 3.0.4?

Thanks.

I’m seeing this on 3.0.4.

Hi,

It would really help if you could extract a simple test case that reproduces the issue you met.
We will wait for your update and discuss it with our internal team for further suggestions.

Thanks.