PyTorch normalization in Deepstream config

Hi!

I understand that image normalization in DeepStream is controlled by net-scale-factor and offsets. I have seen sample configs that just use net-scale-factor=0.0039215697906911373, which is essentially 1/255, i.e. a plain division by 255.

However, my model uses the classic PyTorch normalization: mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]. How do I translate this into the DeepStream config?

I took a look at the RetinaNet example (retinanet-examples/infer_config_batch1.txt at main · NVIDIA/retinanet-examples · GitHub), and it looks like they use net-scale-factor=0.017352074, offsets=123.675;116.28;103.53 with the same PyTorch normalization as mine. Is that correct?
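
For reference, on the PyTorch side this corresponds to the usual torchvision-style preprocessing (a minimal sketch, not necessarily the exact training script):

from torchvision import transforms

# ToTensor() scales pixel values from [0, 255] to [0, 1];
# Normalize() then applies (x - mean) / std per channel.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])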

Hi,

The normalization equation used in DeepStream looks like this:
https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_plugin_gst-nvinfer.html

y = net-scale-factor*(x-offsets)

The offsets parameter plays the same role as the mean value in PyTorch.
However, we don't have a config parameter for a per-channel std, only the single scalar net-scale-factor.
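
For example, plugging in the sample-config value from the first post (net-scale-factor ≈ 1/255, with no offsets configured so nothing is subtracted) reduces the equation to a plain division by 255. A minimal numpy sketch:

import numpy as np

x = np.array([0.0, 127.5, 255.0])           # example pixel values in [0, 255]
net_scale_factor = 0.0039215697906911373    # ~= 1/255
y = net_scale_factor * x                    # -> [0.0, 0.5, 1.0], i.e. scaled to [0, 1]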

Here is a discussion about calculating the corresponding net-scale-factor and offsets values from the mean and std.
It's recommended to check it first:

Thanks.

Hi,

but I don't understand the calculations:

If I have mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225], then I can take an average value for the std, e.g. 0.226, and calculate net-scale-factor = 1/128/0.578 * 0.226 = 0.0030547145328.
But what do I do with the mean values [0.485, 0.456, 0.406]? Should I add an offsets parameter?

Hi rostislav.etc,

Please open a new topic for your issue. Thanks

The DeepStream config has net-scale-factor, which you can essentially use for the 1/std part of the normalisation (though not channel-wise). There is also an offsets parameter, which can be used for the (x - mean) part of the normalisation (channel-wise).

So now we just need to factor in that PyTorch works on pixel values scaled to [0, 1], while DeepStream does not rescale and operates on the original [0, 255] range.

Thus, if during training we applied mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225] to inputs scaled to [0, 1], we need to unscale those values back to the [0, 255] range. The mean can be unscaled channel-wise and used for offsets.

>>> import numpy as np
>>> np.array([0.485, 0.456, 0.406]) * 255
array([123.675, 116.28 , 103.53 ])

For net-scale-factor, which is a single scalar, we take the mean across channels of the std = [0.229, 0.224, 0.225] used in training and unscale it the same way.

>>> np.array([0.229, 0.224, 0.225]).mean() * 255
57.63

And our net-scale-factor is then 1 / (unscaled std) = 1/57.63 = 0.01735207357279195.

And those are the same values used in the RetinaNet example mentioned above: https://github.com/NVIDIA/retinanet-examples/blob/master/extras/deepstream/deepstream-sample/infer_config_batch1.txt.
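
For completeness, the relevant lines of the nvinfer config would then look like this (a sketch with all other properties omitted; model-color-format=0 selects RGB, which is what typical PyTorch/torchvision models expect):

[property]
net-scale-factor=0.017352074
offsets=123.675;116.28;103.53
model-color-format=0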

With those calculated net-scale-factor and offsets applied, our models produce the same results during DeepStream inference as when we test them within the PyTorch framework.
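
As a quick numerical sanity check (a sketch; since the per-channel std is replaced by its average, the match is approximate rather than exact):

import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# DeepStream-style parameters derived above
offsets = mean * 255                             # 123.675, 116.28, 103.53
net_scale_factor = 1.0 / (std.mean() * 255)      # ~0.017352

x = np.array([90.0, 160.0, 200.0])               # an example RGB pixel in [0, 255]

pytorch_y = (x / 255 - mean) / std               # what torchvision Normalize computes
deepstream_y = net_scale_factor * (x - offsets)  # what nvinfer computes

print(pytorch_y)
print(deepstream_y)  # close but not identical, because the std was averaged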

Hope that helps!
