I want to convert an MXNet model to TensorRT through the Caffe parser (MXNet -> Caffe -> TensorRT). The issue is with Caffe’s padding convention.
Assume the input is 28x28 (HxW), the pooling kernel is 3x3, the stride is 2x2, and the pad is 0. Caffe gives a pooling output size of 14x14; MXNet gives 13x13.
I noticed the API nvinfer1::INetworkDefinition::setPoolingOutputDimensionsFormula(IOutputDimensionsFormula *formula), and the default formula in each dimension is (inputDim + padding * 2 - kernelSize) / stride + 1. Following this convention, the pooling output size should be 13x13 rather than 14x14.
How can I ensure that the pooling layer follows the default convention even when the Caffe parser is used?
Could you share how you calculate the Caffe output dimension?
Based on the Caffe documentation: http://caffe.berkeleyvision.org/tutorial/layers/convolution.html
h_o = (h_i + 2 * pad_h - kernel_h) / stride_h + 1, which also gives 13x13.
Please let us know if we missed anything.
Caffe’s pooling convention is:
pooled_height_ = static_cast<int>(ceil(static_cast<float>(height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
pooled_width_ = static_cast<int>(ceil(static_cast<float>(width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
MXNet uses floor, Caffe uses ceil. This is how 13 vs 14 happens.
I think TensorRT’s default setting uses floor, right?
These are the two definitions.
One is known as the ‘valid’ pooling convention and the other as the ‘full’ pooling convention.
valid (MXNet default):
f(x, k, p, s) = floor((x+2p-k)/s)+1
full (compatible with Caffe):
f(x, k, p, s) = ceil((x+2p-k)/s)+1
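To check the two definitions against the 28x28 example from the first post, here is a minimal C++ sketch. The function names pooledSizeValid and pooledSizeFull are made up for illustration and do not come from MXNet, Caffe, or TensorRT. It only covers the formulas as written above, and leaves out the extra clipping step Caffe applies when the padding is nonzero.

```cpp
#include <cmath>

// 'valid' convention (MXNet default): floor((x + 2p - k) / s) + 1
// x = input size, k = kernel size, p = padding, s = stride (one dimension).
int pooledSizeValid(int x, int k, int p, int s)
{
    return static_cast<int>(std::floor(static_cast<float>(x + 2 * p - k) / s)) + 1;
}

// 'full' convention (Caffe-compatible): ceil((x + 2p - k) / s) + 1
int pooledSizeFull(int x, int k, int p, int s)
{
    return static_cast<int>(std::ceil(static_cast<float>(x + 2 * p - k) / s)) + 1;
}
```

For x = 28, k = 3, p = 0, s = 2 the quotient is 12.5, so pooledSizeValid returns 13 and pooledSizeFull returns 14. When (x + 2p - k) is an exact multiple of s, both functions return the same value.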
Yes, so the question is: how can I choose between the two padding conventions in TensorRT?
Specifically, could you please give an example on the usage of nvinfer1::INetworkDefinition::setPoolingOutputDimensionsFormula(IOutputDimensionsFormula *formula)?
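As a rough idea of what a floor-based formula object might look like, here is a sketch that compiles without TensorRT. The DimsHW and IOutputDimensionsFormula types below are local stand-ins written for this example, not the real declarations from NvInfer.h. The compute() parameter list is an assumption and should be checked against the header of the TensorRT release in use. FloorPoolingFormula is a hypothetical name.

```cpp
#include <cmath>

// Stand-in for nvinfer1::DimsHW (assumption: height/width pair).
struct DimsHW
{
    int h;
    int w;
};

// Stand-in for nvinfer1::IOutputDimensionsFormula (assumed signature;
// verify against NvInfer.h before using with the real library).
struct IOutputDimensionsFormula
{
    virtual DimsHW compute(DimsHW inputDims, DimsHW kernelSize, DimsHW stride,
                           DimsHW padding, DimsHW dilation, const char* layerName) const = 0;
    virtual ~IOutputDimensionsFormula() = default;
};

// Hypothetical formula that always uses floor ('valid' convention).
// Dilation and layer name are ignored in this sketch.
class FloorPoolingFormula : public IOutputDimensionsFormula
{
public:
    DimsHW compute(DimsHW inputDims, DimsHW kernelSize, DimsHW stride,
                   DimsHW padding, DimsHW /*dilation*/, const char* /*layerName*/) const override
    {
        return DimsHW{floorDim(inputDims.h, kernelSize.h, padding.h, stride.h),
                      floorDim(inputDims.w, kernelSize.w, padding.w, stride.w)};
    }

private:
    // floor((x + 2p - k) / s) + 1 for a single dimension.
    static int floorDim(int x, int k, int p, int s)
    {
        return static_cast<int>(std::floor(static_cast<float>(x + 2 * p - k) / s)) + 1;
    }
};
```

With the real headers, an instance of such a class would presumably be passed to setPoolingOutputDimensionsFormula() on the network definition. Whether the Caffe parser later replaces it with its own formula is not something this sketch can confirm.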
Sorry, I do not know that. But if you trained the model in MXNet, you can train it again with the ‘full’ option and then port it to Caffe.
We don’t have a sample to demonstrate setPoolingOutputDimensionsFormula() API, but you can check this document for some information:
Thanks and sorry for the inconvenience.