Trouble converting a non-square-grid YOLO network to TensorRT via DeepStream

Hi all.

YOLO-Lite: I modified trt_utils.cpp so that it can convert yolo-lite (a modified version of tiny-yolov2 with no batch normalization layers) to a TensorRT engine, and this runs successfully.

I want to use the yolo-lite model with a non-square grid. As you know, YOLO networks use a square grid by default, such as 9x9, 11x11, 13x13, or 7x7 (for yolo-lite). But for some types of object detection it is better to use a non-square grid, such as 30x10.
I trained a yolo-lite with a 7x5 grid successfully, but when I try to convert it to TensorRT in DeepStream, an error is raised from line 429 in yolo.cpp:

assert(m_InputW == m_InputH);

This is because our input size is 160x240 (width and height differ).

So I commented out that line. I also commented out this line in yolo.cpp:

assert(prevTensorDims.d[1] == prevTensorDims.d[2]);

After that, I modified the following lines of trt_utils.h in class YoloTinyMaxpoolPaddingFormula (the changed lines are marked), because there are now two different sizes for the output dimensions:

class YoloTinyMaxpoolPaddingFormula : public nvinfer1::IOutputDimensionsFormula
{
private:
    std::set<std::string> m_SamePaddingLayers;

    nvinfer1::DimsHW compute(nvinfer1::DimsHW inputDims, nvinfer1::DimsHW kernelSize,
                             nvinfer1::DimsHW stride, nvinfer1::DimsHW padding,
                             nvinfer1::DimsHW dilation, const char* layerName) const override
    {
        // assert(inputDims.d[0] == inputDims.d[1]); // disabled: input may be non-square
        assert(kernelSize.d[0] == kernelSize.d[1]);
        assert(stride.d[0] == stride.d[1]);
        assert(padding.d[0] == padding.d[1]);

        int outputDim1; // height, from d[0]
        int outputDim2; // width, from d[1]
        // Only layer maxpool_12 makes use of same padding
        if (m_SamePaddingLayers.find(layerName) != m_SamePaddingLayers.end())
        {
            outputDim1 = (inputDims.d[0] + 2 * padding.d[0]) / stride.d[0];
            outputDim2 = (inputDims.d[1] + 2 * padding.d[1]) / stride.d[1]; // changed
        }
        // Valid Padding
        else
        {
            outputDim1 = (inputDims.d[0] - kernelSize.d[0]) / stride.d[0] + 1;
            outputDim2 = (inputDims.d[1] - kernelSize.d[1]) / stride.d[1] + 1; // changed
        }
        return nvinfer1::DimsHW{outputDim1, outputDim2}; // changed: separate H and W
    }

public:
    void addSamePaddingLayer(std::string input) { m_SamePaddingLayers.insert(input); }
};
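The two branches of compute() can be checked outside TensorRT. Below is a minimal standalone sketch, assuming a plain HW struct as a stand-in for nvinfer1::DimsHW (d[0] = height, d[1] = width) and the same "same" and "valid" padding formulas as the class above; the helper names are hypothetical.

```cpp
#include <cassert>

// Stand-in for nvinfer1::DimsHW: h corresponds to d[0], w to d[1].
struct HW { int h; int w; };

// Output size along one axis, using the same two formulas as compute().
int poolAxis(int in, int kernel, int stride, int pad, bool samePadding)
{
    if (samePadding)
        return (in + 2 * pad) / stride;   // "same" padding branch
    return (in - kernel) / stride + 1;    // "valid" padding branch
}

// Height and width are computed independently, so a non-square
// input stays non-square after pooling.
HW poolOutput(HW in, int kernel, int stride, int pad, bool samePadding)
{
    return HW{poolAxis(in.h, kernel, stride, pad, samePadding),
              poolAxis(in.w, kernel, stride, pad, samePadding)};
}
```

Chaining five 2x2/2 valid-padding pools from a 224x160 (H x W) input gives 112x80, 56x40, 28x20, 14x10 and finally 7x5, the same sequence as the TensorRT build log below.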

The conversion runs without any errors or even warnings, but when the engine is executed, the bounding boxes are localized incorrectly.

Sample output:


https://ibb.co/SyQbHDw

YOLO-Lite network as printed by Darknet (pjreddie) on my training system (RTX 2080 Ti GPU):

compute_capability = 750, cudnn_half = 1 
   layer   filters  size/strd(dil)      input                output
   0 conv     16       3 x 3/ 1    160 x 224 x   3 ->  160 x 224 x  16 0.031 BF
   1 max                2x 2/ 2    160 x 224 x  16 ->   80 x 112 x  16 0.001 BF
   2 conv     32       3 x 3/ 1     80 x 112 x  16 ->   80 x 112 x  32 0.083 BF
   3 max                2x 2/ 2     80 x 112 x  32 ->   40 x  56 x  32 0.000 BF
   4 conv     64       3 x 3/ 1     40 x  56 x  32 ->   40 x  56 x  64 0.083 BF
   5 max                2x 2/ 2     40 x  56 x  64 ->   20 x  28 x  64 0.000 BF
   6 conv    128       3 x 3/ 1     20 x  28 x  64 ->   20 x  28 x 128 0.083 BF
   7 max                2x 2/ 2     20 x  28 x 128 ->   10 x  14 x 128 0.000 BF
   8 conv    128       3 x 3/ 1     10 x  14 x 128 ->   10 x  14 x 128 0.041 BF
   9 max                2x 2/ 2     10 x  14 x 128 ->    5 x   7 x 128 0.000 BF
  10 conv    256       3 x 3/ 1      5 x   7 x 128 ->    5 x   7 x 256 0.021 BF
  11 conv     35       1 x 1/ 1      5 x   7 x 256 ->    5 x   7 x  35 0.001 BF
  12 detection

Terminal output when converting to TensorRT with DeepStream:


https://ibb.co/Sn21mJC

Loading pre-trained weights...
Loading complete!
Total Number of weights read : 549187
layer inp_size out_size weightPtr
(1) conv-bn-leaky 3 x 224 x 160 16 x 224 x 160 448
(2) maxpool 16 x 224 x 160 16 x 112 x 80 448
(3) conv-bn-leaky 16 x 112 x 80 32 x 112 x 80 5088
(4) maxpool 32 x 112 x 80 32 x 56 x 40 5088
(5) conv-bn-leaky 32 x 56 x 40 64 x 56 x 40 23584
(6) maxpool 64 x 56 x 40 64 x 28 x 20 23584
(7) conv-bn-leaky 64 x 28 x 20 128 x 28 x 20 97440
(8) maxpool 128 x 28 x 20 128 x 14 x 10 97440
(9) conv-bn-leaky 128 x 14 x 10 128 x 14 x 10 245024
(10) maxpool 128 x 14 x 10 128 x 7 x 5 245024
(11) conv-bn-leaky 128 x 7 x 5 256 x 7 x 5 540192
(12) conv-linear 256 x 7 x 5 35 x 7 x 5 549187
(13) region 35 x 7 x 5 35 x 7 x 5 549187
Anchors are being converted to network input resolution i.e. Anchors x 22 (stride)
Output blob names :
region_13
Total number of layers: 21
Total number of layers on DLA: 0
Building the TensorRT Engine...
Building complete!

Hi,

Our sample only supports symmetric (square) input.
The non-symmetric case requires more changes.

For example:

else if (m_configBlocks.at(i).at("type") == "yolo")
{
    nvinfer1::Dims prevTensorDims = previous->getDimensions();
    assert(prevTensorDims.d[1] == prevTensorDims.d[2]);
    TensorInfo& curYoloTensor = m_OutputTensors.at(outputTensorCount);
    curYoloTensor.gridSize = prevTensorDims.d[1];
    curYoloTensor.stride = m_InputW / curYoloTensor.gridSize;
    ...

Thanks.
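In this square-only code, gridSize is taken from d[1] of the C x H x W tensor (the grid height) while the stride divides m_InputW; the sample's region branch uses the same formula. For the 35 x 7 x 5 region output in the log above, that gives 160 / 7 = 22 in integer arithmetic, which appears to be the "Anchors x 22 (stride)" value printed during the build, whereas the real stride is 32 on both axes (224 / 7 and 160 / 5). A minimal standalone sketch comparing the two computations; GridInfo, squareOnlyStride, perAxisGrid and outputVolume are hypothetical names, not the sample's API.

```cpp
#include <cassert>
#include <cstdint>

// Grid and stride of a region/yolo output whose TensorRT dims are C x H x W.
struct GridInfo {
    unsigned gridH;   // prevTensorDims.d[1]
    unsigned gridW;   // prevTensorDims.d[2]
    unsigned strideH; // input height / grid height
    unsigned strideW; // input width / grid width
};

// Square-only computation: gridSize = d[1], stride = m_InputW / gridSize.
unsigned squareOnlyStride(unsigned gridFromD1, unsigned inputW)
{
    return inputW / gridFromD1;
}

// Per-axis computation: each stride uses the matching input dimension.
GridInfo perAxisGrid(unsigned gridH, unsigned gridW, unsigned inputH, unsigned inputW)
{
    return GridInfo{gridH, gridW, inputH / gridH, inputW / gridW};
}

// Output volume: every anchor predicts 4 box coordinates,
// 1 objectness score and numClasses class scores per grid cell.
uint64_t outputVolume(const GridInfo& g, unsigned numBBoxes, unsigned numClasses)
{
    return static_cast<uint64_t>(g.gridH) * g.gridW * numBBoxes * (5 + numClasses);
}
```

For the yolo-lite model (5 anchors, 2 classes), perAxisGrid(7, 5, 224, 160) gives strides of 32 and 32 and a volume of 7 * 5 * 5 * 7 = 1225.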

Hi,

Can you also post your cfg file here?

I am also facing a similar issue; I am unable to load the network.

Here is my cfg file:

[net]

# Testing
# batch=1
# subdivisions=1

# Training
batch=64
subdivisions=8

height=128
width=288
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1

learning_rate=0.001
#burn_in=1200
max_batches = 25000
policy=steps
steps=2000,6000,9000,22000
scales=.1,.1,.1,.1

[convolutional]
batch_normalize=1
filters=16
size=3
stride=1
pad=1
activation=leaky

[maxpool]
size=2
stride=2

[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky

[maxpool]
size=2
stride=2

[convolutional]
batch_normalize=1
filters=96
size=3
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=32
size=1
stride=1
pad=1
activation=leaky

[maxpool]
size=2
stride=2

[convolutional]
batch_normalize=1
filters=160
size=3
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=64
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=192
activation=leaky

[convolutional]
size=1
stride=1
pad=1
filters=12
activation=linear

[yolo]
iou_loss= giou
mask = 0,1
anchors = 33,35,  22,53
classes=1
num=2
jitter=.3

random=0
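Two sanity checks for this cfg can be done by hand, and the sketch below encodes them with hypothetical helper names. With mask = 0,1 and classes=1, the last [convolutional] layer needs 2 * (5 + 1) = 12 filters, which matches filters=12. With height=128, width=288 and three 2x2/2 [maxpool] layers, and assuming every convolution preserves the spatial size, the [yolo] grid should be 16 (H) x 36 (W), i.e. non-square.

```cpp
#include <cassert>

// Filters needed in the [convolutional] layer feeding a [yolo] layer:
// each anchor listed in `mask` predicts 4 box coordinates,
// 1 objectness score and numClasses class scores.
int yoloFilters(int numMasks, int numClasses)
{
    return numMasks * (5 + numClasses);
}

// Size along one axis after numPools 2x2/2 maxpools with valid padding,
// assuming the convolutions in between keep the spatial size unchanged.
int gridAfterPools(int inputSize, int numPools)
{
    for (int i = 0; i < numPools; ++i)
        inputSize = (inputSize - 2) / 2 + 1;
    return inputSize;
}
```

The same filter rule applies to the yolo-lite [region] layer: num=5 and classes=2 give 5 * (5 + 2) = 35, matching filters=35 in that cfg.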

Edit:

After making some changes in the nvdsinfer_custom_impl_Yolo code, I was able to load and run the network.
The changes I made were removing some size-check asserts, computing and using gridSize and stride separately for the H and W dimensions, and a few others.

However, the network produced incorrect output, and in some cases no output at all.
What I noticed was that the network info printed by Darknet and by nvdsinfer differed: after the first maxpool, the output dimensions printed by nvdsinfer had equal height and width (which should not happen).

Darknet output

https://imgur.com/a/evM8ABl

NvdsInfer output

https://imgur.com/a/H41W1a8

The output dimensions returned after netAddMaxpool are wrong, so it looks like nvinfer1::IPoolingLayer is doing something wrong.
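One plausible explanation for equal dimensions right after the first maxpool: the sample's padding formula (judging by the "-" lines of the trt_utils.h hunk in the patch later in this thread) computes a single outputDim from the height and returns it for both axes. If that formula is still active in the build, 128x288 (H x W, from height=128 and width=288) becomes 64x64 at the first maxpool. A standalone sketch with a plain HW struct in place of nvinfer1::DimsHW and hypothetical function names, contrasting that behaviour with a per-axis formula for this cfg's three 2x2/2 maxpools:

```cpp
#include <cassert>

// Stand-in for nvinfer1::DimsHW: h corresponds to d[0], w to d[1].
struct HW { int h; int w; };

// Square-only behaviour: one valid-padding output size is computed
// from the height and returned for both axes.
HW squareOnlyPool(HW in, int kernel, int stride)
{
    int outputDim = (in.h - kernel) / stride + 1;
    return HW{outputDim, outputDim};
}

// Per-axis behaviour: height and width are computed independently.
HW perAxisPool(HW in, int kernel, int stride)
{
    return HW{(in.h - kernel) / stride + 1, (in.w - kernel) / stride + 1};
}
```

After three pools, the square-only version ends at a 16x16 grid, while the per-axis version gives the expected 16x36.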

[net]
# Testing
batch=128
subdivisions=1
# batch=64
# subdivisions=32
# Training
# batch=32
# subdivisions=1
width=160
height=224
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1

learning_rate=0.001
burn_in=1000
max_batches = 1500000
policy=steps
steps  = 250000,300000,700000,800000
scales = 10    ,.1    ,.1    ,.1
#steps  = 124100,250000,400000,450000
#scales = 10    ,.1    ,.1    ,.1
#250001,10    ,
[convolutional]
batch_normalize=0
filters=16
size=3
stride=1
pad=1
activation=leaky

[maxpool]
size=2
stride=2

[convolutional]
batch_normalize=0
filters=32
size=3
stride=1
pad=1
activation=leaky

[maxpool]
size=2
stride=2

[convolutional]
batch_normalize=0
filters=64
size=3
stride=1
pad=1
activation=leaky

[maxpool]
size=2
stride=2

[convolutional]
batch_normalize=0
filters=128
size=3
stride=1
pad=1
activation=leaky

[maxpool]
size=2
stride=2


[convolutional]
batch_normalize=0
filters=128
size=3
stride=1
pad=1
activation=leaky

[maxpool]
size=2
stride=2

###########

[convolutional]
batch_normalize=0
size=3
stride=1
pad=1
filters=256
activation=leaky

[convolutional]
size=1
stride=1
pad=1
filters=35
activation=linear

[region]
anchors =  1.08,1.19,  3.42,4.41,  6.63,11.38,  9.42,5.11,  16.62,10.52
bias_match=1
classes=2
coords=4
num=5
softmax=1
jitter=.2
rescore=0

object_scale=5
noobject_scale=1
class_scale=1
coord_scale=1

absolute=1
thresh = .5
random=0

Dear AastaLLL,

I had made exactly these modifications in the first place; I just forgot to mention them in my first question. It still doesn't work. With all the modifications you mentioned, the TensorRT engine is created successfully without errors. However, the bounding boxes are not located correctly, as you can see below:


https://ibb.co/CHYY2VL

Hey barzanhayati

Which version of TensorRT are you using?
I am also facing a similar issue: in my case, after making the changes, the output dimensions returned by nvinfer1::IPoolingLayer are wrong. But in your case, maxpool seems to return the correct dimensions.

Hi NvCJR, AastaLLL,

Can you please confirm whether there is an issue in the TensorRT maxpool layer, or whether my findings are wrong?

I am in the middle of creating an app on Jetson Nano that I need to present to my team so that we can decide which hardware to go ahead with.

TensorRT: 5.1.5

Hi,

Does anyone from Nvidia have any update on this issue?

Hi,

A couple of suggestions:

  1. When you have a question, please create a new thread, and feel free to link to similar threads so it's easy to respond.

  2. The yolo samples provided in the SDK are expected to work only for the standard models.

After making some changes in the nvdsinfer_custom_impl_Yolo code, I was able to load and run the network.
The changes I made were removing some size-check asserts, computing and using gridSize and stride in both the H and W dimensions, and some others.

These checks are there to make sure the network builder is only used with square inputs. We will provide a sample for non-square inputs in a future release; until then, feel free to modify the current YOLO network builder source code to fit your needs. Regarding the maxpool layer, there are no bugs in its implementation. If your maxpool layer output does not match the Darknet output, I would suggest looking at the padding type and the stride used. Based on the changes you've made, the stride or padding type may not match the Darknet implementation.
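When comparing maxpool outputs against Darknet, note that the two logs in this thread print dimensions in different orders: Darknet's layer table prints W x H x C ("160 x 224 x 3" for width=160, height=224), while the DeepStream build log prints C x H x W ("3 x 224 x 160"). A small sketch for checking each layer pair, assuming the orders seen in those logs; the struct and function names are hypothetical.

```cpp
#include <cassert>

// Shapes stored in the order each tool prints them.
struct DarknetShape { int w, h, c; }; // Darknet: W x H x C
struct TrtShape { int c, h, w; };     // DeepStream/TensorRT: C x H x W

// True when both printouts describe the same tensor shape.
bool sameShape(const DarknetShape& dn, const TrtShape& trt)
{
    return dn.w == trt.w && dn.h == trt.h && dn.c == trt.c;
}
```

Under this reading, every layer pair in the yolo-lite logs earlier in the thread matches, e.g. Darknet layer 1 "80 x 112 x 16" and TensorRT layer (2) "16 x 112 x 80".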

If anyone is still looking for a working solution for asymmetric models, I’ve rolled a patch below:

diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/kernels.cu b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/kernels.cu
index 45032f0..43f1906 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/kernels.cu
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/kernels.cu
@@ -17,20 +17,20 @@
 
 inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
 
-__global__ void gpuYoloLayerV3(const float* input, float* output, const uint gridSize, const uint numOutputClasses,
+__global__ void gpuYoloLayerV3(const float* input, float* output, const uint gridSizeX,  const uint gridSizeY, const uint numOutputClasses,
                                const uint numBBoxes)
 {
     uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
     uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
     uint z_id = blockIdx.z * blockDim.z + threadIdx.z;
 
-    if ((x_id >= gridSize) || (y_id >= gridSize) || (z_id >= numBBoxes))
+    if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes))
     {
         return;
     }
 
-    const int numGridCells = gridSize * gridSize;
-    const int bbindex = y_id * gridSize + x_id;
+    const int numGridCells = gridSizeX * gridSizeY;
+    const int bbindex = y_id * gridSizeX + x_id;
 
     output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
         = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]);
@@ -54,23 +54,23 @@ __global__ void gpuYoloLayerV3(const float* input, float* output, const uint gri
     }
 }
 
-cudaError_t cudaYoloLayerV3(const void* input, void* output, const uint& batchSize, const uint& gridSize,
+cudaError_t cudaYoloLayerV3(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
                             const uint& numOutputClasses, const uint& numBBoxes,
                             uint64_t outputSize, cudaStream_t stream);
 
-cudaError_t cudaYoloLayerV3(const void* input, void* output, const uint& batchSize, const uint& gridSize,
+cudaError_t cudaYoloLayerV3(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
                             const uint& numOutputClasses, const uint& numBBoxes,
                             uint64_t outputSize, cudaStream_t stream)
 {
     dim3 threads_per_block(16, 16, 4);
-    dim3 number_of_blocks((gridSize / threads_per_block.x) + 1,
-                          (gridSize / threads_per_block.y) + 1,
+    dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
+                          (gridSizeY / threads_per_block.y) + 1,
                           (numBBoxes / threads_per_block.z) + 1);
     for (unsigned int batch = 0; batch < batchSize; ++batch)
     {
         gpuYoloLayerV3<<<number_of_blocks, threads_per_block, 0, stream>>>(
             reinterpret_cast<const float*>(input) + (batch * outputSize),
-            reinterpret_cast<float*>(output) + (batch * outputSize), gridSize, numOutputClasses,
+            reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
             numBBoxes);
     }
     return cudaGetLastError();
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp
index 4226027..45399ef 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp
@@ -29,7 +29,7 @@
 #include <iostream>
 #include <unordered_map>
 
-static const int NUM_CLASSES_YOLO = 80;
+static const int NUM_CLASSES_YOLO = 13;
 
 extern "C" bool NvDsInferParseCustomYoloV3(
     std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
@@ -167,22 +167,22 @@ nmsAllClasses(const float nmsThresh,
 static std::vector<NvDsInferParseObjectInfo>
 decodeYoloV2Tensor(
     const float* detections, const std::vector<float> &anchors,
-    const uint gridSize, const uint stride, const uint numBBoxes,
+    const uint gridSizeX, const uint gridSizeY, const uint stride, const uint numBBoxes,
     const uint numOutputClasses, const float probThresh, const uint& netW,
     const uint& netH)
 {
     std::vector<NvDsInferParseObjectInfo> binfo;
-    for (uint y = 0; y < gridSize; ++y)
+    for (uint y = 0; y < gridSizeY; ++y)
     {
-        for (uint x = 0; x < gridSize; ++x)
+        for (uint x = 0; x < gridSizeX; ++x)
         {
             for (uint b = 0; b < numBBoxes; ++b)
             {
                 const float pw = anchors[b * 2];
                 const float ph = anchors[b * 2 + 1];
 
-                const int numGridCells = gridSize * gridSize;
-                const int bbindex = y * gridSize + x;
+                const int numGridCells = gridSizeX * gridSizeY;
+                const int bbindex = y * gridSizeX + x;
                 const float bx
                     = x + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 0)];
                 const float by
@@ -226,22 +226,22 @@ decodeYoloV2Tensor(
 static std::vector<NvDsInferParseObjectInfo>
 decodeYoloV3Tensor(
     const float* detections, const std::vector<int> &mask, const std::vector<float> &anchors,
-    const uint gridSize, const uint stride, const uint numBBoxes,
+    const uint gridSizeX, const uint gridSizeY, const uint stride, const uint numBBoxes,
     const uint numOutputClasses, const float probThresh, const uint& netW,
     const uint& netH)
 {
     std::vector<NvDsInferParseObjectInfo> binfo;
-    for (uint y = 0; y < gridSize; ++y)
+    for (uint y = 0; y < gridSizeY; ++y)
     {
-        for (uint x = 0; x < gridSize; ++x)
+        for (uint x = 0; x < gridSizeX; ++x)
         {
             for (uint b = 0; b < numBBoxes; ++b)
             {
                 const float pw = anchors[mask[b] * 2];
                 const float ph = anchors[mask[b] * 2 + 1];
 
-                const int numGridCells = gridSize * gridSize;
-                const int bbindex = y * gridSize + x;
+                const int numGridCells = gridSizeX * gridSizeY;
+                const int bbindex = y * gridSizeX + x;
                 const float bx
                     = x + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 0)];
                 const float by
@@ -304,8 +304,8 @@ static bool NvDsInferParseYoloV3(
     const std::vector<std::vector<int>> &masks)
 {
     const uint kNUM_BBOXES = 3;
-    static const float kNMS_THRESH = 0.3f;
-    static const float kPROB_THRESH = 0.7f;
+    static const float kNMS_THRESH = 0.4f;
+    static const float kPROB_THRESH = 0.3f;
 
     const std::vector<const NvDsInferLayerInfo*> sortedLayers =
         SortLayers (outputLayersInfo);
@@ -328,11 +328,12 @@ static bool NvDsInferParseYoloV3(
     for (uint idx = 0; idx < masks.size(); ++idx) {
         const NvDsInferLayerInfo &layer = *sortedLayers[idx]; // 255 x Grid x Grid
         assert (layer.dims.numDims == 3);
-        const uint gridSize = layer.dims.d[1];
-        const uint stride = networkInfo.width / gridSize;
+      const uint gridSizeY = layer.dims.d[1];
+      const uint gridSizeX = layer.dims.d[2];
+      const uint stride = networkInfo.height / gridSizeY;
 
         std::vector<NvDsInferParseObjectInfo> outObjs =
-            decodeYoloV3Tensor((const float*)(layer.buffer), masks[idx], anchors, gridSize, stride, kNUM_BBOXES,
+            decodeYoloV3Tensor((const float*)(layer.buffer), masks[idx], anchors, gridSizeX, gridSizeY, stride, kNUM_BBOXES,
                        NUM_CLASSES_YOLO, kPROB_THRESH, networkInfo.width, networkInfo.height);
         objects.insert(objects.end(), outObjs.begin(), outObjs.end());
     }
@@ -373,8 +374,8 @@ extern "C" bool NvDsInferParseCustomYoloV3Tiny(
         10, 14, 23, 27, 37, 58, 81, 82, 135, 169, 344, 319};
     static const std::vector<std::vector<int>> kMASKS = {
         {3, 4, 5},
-        //{0, 1, 2}}; // as per output result, select {1,2,3}
-        {1, 2, 3}};
+        {0, 1, 2}}; // as per output result, select {1,2,3}
+//        {1, 2, 3}};
 
     return NvDsInferParseYoloV3 (
         outputLayersInfo, networkInfo, detectionParams, objectList,
@@ -408,10 +409,12 @@ static bool NvDsInferParseYoloV2(
     }
 
     assert (layer.dims.numDims == 3);
-    const uint gridSize = layer.dims.d[1];
-    const uint stride = networkInfo.width / gridSize;
+    const uint gridSizeY = layer.dims.d[1];
+    const uint gridSizeX = layer.dims.d[2];
+    const uint stride = networkInfo.height / gridSizeY;
+
     std::vector<NvDsInferParseObjectInfo> objects =
-        decodeYoloV2Tensor((const float*)(layer.buffer), kANCHORS, gridSize, stride, kNUM_BBOXES,
+        decodeYoloV2Tensor((const float*)(layer.buffer), kANCHORS, gridSizeX, gridSizeY, stride, kNUM_BBOXES,
                    NUM_CLASSES_YOLO, probthreshold, networkInfo.width, networkInfo.height);
 
     objectList.clear();
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.cpp b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.cpp
index 17049e8..46ee475 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.cpp
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.cpp
@@ -373,19 +373,20 @@ nvinfer1::ILayer* netAddUpsample(int layerIdx, std::map<std::string, std::string
     assert(block.at("type") == "upsample");
     nvinfer1::Dims inpDims = input->getDimensions();
     assert(inpDims.nbDims == 3);
-    assert(inpDims.d[1] == inpDims.d[2]);
+//    assert(inpDims.d[1] == inpDims.d[2]);
     int h = inpDims.d[1];
     int w = inpDims.d[2];
     int stride = std::stoi(block.at("stride"));
     // add pre multiply matrix as a constant
     nvinfer1::Dims preDims{3,
-                           {1, stride * h, w},
-                           {nvinfer1::DimensionType::kCHANNEL, nvinfer1::DimensionType::kSPATIAL,
+                           {1, stride * h, h},
+                           {nvinfer1::DimensionType::kCHANNEL,
+                            nvinfer1::DimensionType::kSPATIAL,
                             nvinfer1::DimensionType::kSPATIAL}};
-    int size = stride * h * w;
+    int size = stride * h * h;
     nvinfer1::Weights preMul{nvinfer1::DataType::kFLOAT, nullptr, size};
     float* preWt = new float[size];
-    /* (2*h * w)
+    /* (2*h * h)
     [ [1, 0, ..., 0],
       [1, 0, ..., 0],
       [0, 1, ..., 0],
@@ -397,12 +398,9 @@ nvinfer1::ILayer* netAddUpsample(int layerIdx, std::map<std::string, std::string
     */
     for (int i = 0, idx = 0; i < h; ++i)
     {
-        for (int s = 0; s < stride; ++s)
+        for (int j = 0; j < h * stride; ++j, ++idx)
         {
-            for (int j = 0; j < w; ++j, ++idx)
-            {
-                preWt[idx] = (i == j) ? 1.0 : 0.0;
-            }
+            preWt[idx] = (i == j) ? 1.0 : 0.0;
         }
     }
     preMul.values = preWt;
@@ -413,20 +411,20 @@ nvinfer1::ILayer* netAddUpsample(int layerIdx, std::map<std::string, std::string
     preM->setName(preLayerName.c_str());
     // add post multiply matrix as a constant
     nvinfer1::Dims postDims{3,
-                            {1, h, stride * w},
+                            {1, w, stride * w},
                             {nvinfer1::DimensionType::kCHANNEL, nvinfer1::DimensionType::kSPATIAL,
                              nvinfer1::DimensionType::kSPATIAL}};
-    size = stride * h * w;
+    size = stride * w * w;
     nvinfer1::Weights postMul{nvinfer1::DataType::kFLOAT, nullptr, size};
     float* postWt = new float[size];
-    /* (h * 2*w)
+    /* (w * 2*w)
     [ [1, 1, 0, 0, ..., 0, 0],
       [0, 0, 1, 1, ..., 0, 0],
       ...,
       ...,
       [0, 0, 0, 0, ..., 1, 1] ]
     */
-    for (int i = 0, idx = 0; i < h; ++i)
+    for (int i = 0, idx = 0; i < w; ++i)
     {
         for (int j = 0; j < stride * w; ++j, ++idx)
         {
@@ -441,8 +439,8 @@ nvinfer1::ILayer* netAddUpsample(int layerIdx, std::map<std::string, std::string
     post_m->setName(postLayerName.c_str());
     // add matrix multiply layers for upsampling
     nvinfer1::IMatrixMultiplyLayer* mm1
-        = network->addMatrixMultiply(*preM->getOutput(0), nvinfer1::MatrixOperation::kNONE, *input,
-                                     nvinfer1::MatrixOperation::kNONE);
+        = network->addMatrixMultiply(*preM->getOutput(0), nvinfer1::MatrixOperation::kNONE,
+                                     *input,nvinfer1::MatrixOperation::kNONE);
     assert(mm1 != nullptr);
     std::string mm1LayerName = "mm1_" + std::to_string(layerIdx);
     mm1->setName(mm1LayerName.c_str());
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.h b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.h
index 97dcc5f..26e901b 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.h
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.h
@@ -46,23 +46,26 @@ private:
                              nvinfer1::DimsHW stride, nvinfer1::DimsHW padding,
                              nvinfer1::DimsHW dilation, const char* layerName) const override
     {
-        assert(inputDims.d[0] == inputDims.d[1]);
+//        assert(inputDims.d[0] == inputDims.d[1]);
         assert(kernelSize.d[0] == kernelSize.d[1]);
         assert(stride.d[0] == stride.d[1]);
         assert(padding.d[0] == padding.d[1]);
 
-        int outputDim;
+        int outputDimH;
+        int outputDimW;
         // Only layer maxpool_12 makes use of same padding
         if (m_SamePaddingLayers.find(layerName) != m_SamePaddingLayers.end())
         {
-            outputDim = (inputDims.d[0] + 2 * padding.d[0]) / stride.d[0];
+            outputDimH = (inputDims.d[0] + 2 * padding.d[0]) / stride.d[0];
+            outputDimW = (inputDims.d[1] + 2 * padding.d[1]) / stride.d[1];
         }
         // Valid Padding
         else
         {
-            outputDim = (inputDims.d[0] - kernelSize.d[0]) / stride.d[0] + 1;
+            outputDimH = (inputDims.d[0] - kernelSize.d[0]) / stride.d[0] + 1;
+            outputDimW = (inputDims.d[1] - kernelSize.d[1]) / stride.d[1] + 1;
         }
-        return nvinfer1::DimsHW{outputDim, outputDim};
+        return nvinfer1::DimsHW{outputDimH, outputDimW};
     }
 
 public:
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.cpp b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.cpp
index 379694a..bbc34eb 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.cpp
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.cpp
@@ -173,19 +173,21 @@ nvinfer1::INetworkDefinition *Yolo::createYoloNetwork (
         else if (m_configBlocks.at(i).at("type") == "yolo")
         {
             nvinfer1::Dims prevTensorDims = previous->getDimensions();
-            assert(prevTensorDims.d[1] == prevTensorDims.d[2]);
+//            assert(prevTensorDims.d[1] == prevTensorDims.d[2]);
             TensorInfo& curYoloTensor = m_OutputTensors.at(outputTensorCount);
-            curYoloTensor.gridSize = prevTensorDims.d[1];
-            curYoloTensor.stride = m_InputW / curYoloTensor.gridSize;
-            m_OutputTensors.at(outputTensorCount).volume = curYoloTensor.gridSize
-                * curYoloTensor.gridSize
+            curYoloTensor.gridSizeY = prevTensorDims.d[1];
+            curYoloTensor.gridSizeX = prevTensorDims.d[2];
+            curYoloTensor.stride = m_InputH / curYoloTensor.gridSizeY;
+            m_OutputTensors.at(outputTensorCount).volume = curYoloTensor.gridSizeY
+                * curYoloTensor.gridSizeX
                 * (curYoloTensor.numBBoxes * (5 + curYoloTensor.numClasses));
             std::string layerName = "yolo_" + std::to_string(i);
             curYoloTensor.blobName = layerName;
             nvinfer1::IPluginV2* yoloPlugin
                 = new YoloLayerV3(m_OutputTensors.at(outputTensorCount).numBBoxes,
                                   m_OutputTensors.at(outputTensorCount).numClasses,
-                                  m_OutputTensors.at(outputTensorCount).gridSize);
+                                  m_OutputTensors.at(outputTensorCount).gridSizeX,
+                                  m_OutputTensors.at(outputTensorCount).gridSizeY);
             assert(yoloPlugin != nullptr);
             nvinfer1::IPluginV2Layer* yolo = network->addPluginV2(&previous, 1, *yoloPlugin);
             assert(yolo != nullptr);
@@ -206,10 +208,11 @@ nvinfer1::INetworkDefinition *Yolo::createYoloNetwork (
             nvinfer1::Dims prevTensorDims = previous->getDimensions();
             assert(prevTensorDims.d[1] == prevTensorDims.d[2]);
             TensorInfo& curRegionTensor = m_OutputTensors.at(outputTensorCount);
-            curRegionTensor.gridSize = prevTensorDims.d[1];
-            curRegionTensor.stride = m_InputW / curRegionTensor.gridSize;
-            m_OutputTensors.at(outputTensorCount).volume = curRegionTensor.gridSize
-                * curRegionTensor.gridSize
+            curRegionTensor.gridSizeY = prevTensorDims.d[1];
+            curRegionTensor.gridSizeX = prevTensorDims.d[2];
+            curRegionTensor.stride = m_InputW / curRegionTensor.gridSizeX;
+            m_OutputTensors.at(outputTensorCount).volume = curRegionTensor.gridSizeX
+                * curRegionTensor.gridSizeY
                 * (curRegionTensor.numBBoxes * (5 + curRegionTensor.numClasses));
             std::string layerName = "region_" + std::to_string(i);
             curRegionTensor.blobName = layerName;
@@ -423,7 +426,7 @@ void Yolo::parseConfigBlocks()
             m_InputH = std::stoul(block.at("height"));
             m_InputW = std::stoul(block.at("width"));
             m_InputC = std::stoul(block.at("channels"));
-            assert(m_InputW == m_InputH);
+//            assert(m_InputW == m_InputH);
             m_InputSize = m_InputC * m_InputH * m_InputW;
         }
         else if ((block.at("type") == "region") || (block.at("type") == "yolo"))
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.h b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.h
index 968ba2b..f002588 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.h
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.h
@@ -50,7 +50,8 @@ struct TensorInfo
 {
     std::string blobName;
     uint stride{0};
-    uint gridSize{0};
+    uint gridSizeY{0};
+    uint gridSizeX{0};
     uint numClasses{0};
     uint numBBoxes{0};
     uint64_t volume{0};
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp
index e8a90b3..1da010c 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp
@@ -45,7 +45,7 @@ void read(const char*& buffer, T& val)
 // Forward declaration of cuda kernels
 cudaError_t cudaYoloLayerV3 (
     const void* input, void* output, const uint& batchSize,
-    const uint& gridSize, const uint& numOutputClasses,
+    const uint& gridSizeX, const uint& gridSizeY,  const uint& numOutputClasses,
     const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream);
 
 YoloLayerV3::YoloLayerV3 (const void* data, size_t length)
@@ -53,20 +53,23 @@ YoloLayerV3::YoloLayerV3 (const void* data, size_t length)
     const char *d = static_cast<const char*>(data);
     read(d, m_NumBoxes);
     read(d, m_NumClasses);
-    read(d, m_GridSize);
+    read(d, m_GridSizeX);
+    read(d, m_GridSizeY);
     read(d, m_OutputSize);
 };
 
 YoloLayerV3::YoloLayerV3 (
-    const uint& numBoxes, const uint& numClasses, const uint& gridSize) :
+    const uint& numBoxes, const uint& numClasses, const uint& gridSizeX, const uint& gridSizeY) :
     m_NumBoxes(numBoxes),
     m_NumClasses(numClasses),
-    m_GridSize(gridSize)
+    m_GridSizeX(gridSizeX),
+    m_GridSizeY(gridSizeY)
 {
     assert(m_NumBoxes > 0);
     assert(m_NumClasses > 0);
-    assert(m_GridSize > 0);
-    m_OutputSize = m_GridSize * m_GridSize * (m_NumBoxes * (4 + 1 + m_NumClasses));
+    assert(m_GridSizeX > 0);
+    assert(m_GridSizeY > 0);
+    m_OutputSize = m_GridSizeX * m_GridSizeY * (m_NumBoxes * (4 + 1 + m_NumClasses));
 };
 
 nvinfer1::Dims
@@ -100,14 +103,14 @@ int YoloLayerV3::enqueue(
     cudaStream_t stream)
 {
     CHECK(cudaYoloLayerV3(
-              inputs[0], outputs[0], batchSize, m_GridSize, m_NumClasses, m_NumBoxes,
+              inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
               m_OutputSize, stream));
     return 0;
 }
 
 size_t YoloLayerV3::getSerializationSize() const
 {
-    return sizeof(m_NumBoxes) + sizeof(m_NumClasses) + sizeof(m_GridSize) + sizeof(m_OutputSize);
+    return sizeof(m_NumBoxes) + sizeof(m_NumClasses) + sizeof(m_GridSizeX) + sizeof(m_GridSizeY) + sizeof(m_OutputSize);
 }
 
 void YoloLayerV3::serialize(void* buffer) const
@@ -115,13 +118,14 @@ void YoloLayerV3::serialize(void* buffer) const
     char *d = static_cast<char*>(buffer);
     write(d, m_NumBoxes);
     write(d, m_NumClasses);
-    write(d, m_GridSize);
+    write(d, m_GridSizeX);
+    write(d, m_GridSizeY);
     write(d, m_OutputSize);
 }
 
 nvinfer1::IPluginV2* YoloLayerV3::clone() const
 {
-    return new YoloLayerV3 (m_NumBoxes, m_NumClasses, m_GridSize);
+    return new YoloLayerV3 (m_NumBoxes, m_NumClasses, m_GridSizeX, m_GridSizeY);
 }
 
 REGISTER_TENSORRT_PLUGIN(YoloLayerV3PluginCreator);
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.h b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.h
index f10047e..21487b2 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.h
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.h
@@ -51,7 +51,7 @@ class YoloLayerV3 : public nvinfer1::IPluginV2
 {
 public:
     YoloLayerV3 (const void* data, size_t length);
-    YoloLayerV3 (const uint& numBoxes, const uint& numClasses, const uint& gridSize);
+    YoloLayerV3 (const uint& numBoxes, const uint& numClasses, const uint& gridSizeX, const uint& gridSizeY);
     const char* getPluginType () const override { return YOLOV3LAYER_PLUGIN_NAME; }
     const char* getPluginVersion () const override { return YOLOV3LAYER_PLUGIN_VERSION; }
     int getNbOutputs () const override { return 1; }
@@ -89,7 +89,8 @@ public:
 private:
     uint m_NumBoxes {0};
     uint m_NumClasses {0};
-    uint m_GridSize {0};
+    uint m_GridSizeX {0};
+    uint m_GridSizeY {0};
     uint64_t m_OutputSize {0};
     std::string m_Namespace {""};
 };
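The plugin's serialized blob now carries both grid dimensions. The field order in `serialize()` and in the deserializing constructor therefore has to match exactly. Engine files serialized by the unpatched plugin (three `uint` fields plus the `uint64_t`) will presumably not deserialize correctly with the patched reader and should be rebuilt. Below is a minimal host-only sketch of that round trip. It uses hypothetical stand-ins (`writeVal`, `readVal`, `YoloLayerState`) rather than the real plugin class:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-ins for the write()/read() helpers in yoloPlugins.cpp:
// copy a trivially copyable value to/from a byte buffer and advance the cursor.
template <typename T>
void writeVal(char*& buffer, const T& val)
{
    std::memcpy(buffer, &val, sizeof(T));
    buffer += sizeof(T);
}

template <typename T>
void readVal(const char*& buffer, T& val)
{
    std::memcpy(&val, buffer, sizeof(T));
    buffer += sizeof(T);
}

// Hypothetical mirror of the patched YoloLayerV3 members.
struct YoloLayerState
{
    unsigned numBoxes;
    unsigned numClasses;
    unsigned gridSizeX;
    unsigned gridSizeY;
    uint64_t outputSize;
};

// Same formula as the patched constructor: gridSizeX * gridSizeY cells, each
// holding numBoxes * (4 box coords + 1 objectness + numClasses) floats.
uint64_t computeOutputSize(unsigned numBoxes, unsigned numClasses,
                           unsigned gridSizeX, unsigned gridSizeY)
{
    return static_cast<uint64_t>(gridSizeX) * gridSizeY * (numBoxes * (4 + 1 + numClasses));
}

std::size_t serializationSize(const YoloLayerState& s)
{
    return sizeof(s.numBoxes) + sizeof(s.numClasses) + sizeof(s.gridSizeX)
        + sizeof(s.gridSizeY) + sizeof(s.outputSize);
}

// Field order must be identical here and in deserialize().
std::vector<char> serialize(const YoloLayerState& s)
{
    std::vector<char> blob(serializationSize(s));
    char* d = blob.data();
    writeVal(d, s.numBoxes);
    writeVal(d, s.numClasses);
    writeVal(d, s.gridSizeX);
    writeVal(d, s.gridSizeY);
    writeVal(d, s.outputSize);
    return blob;
}

YoloLayerState deserialize(const std::vector<char>& blob)
{
    YoloLayerState s{};
    const char* d = blob.data();
    readVal(d, s.numBoxes);
    readVal(d, s.numClasses);
    readVal(d, s.gridSizeX);
    readVal(d, s.gridSizeY);
    readVal(d, s.outputSize);
    return s;
}
```

For example, a 7x5 grid with 5 boxes and 1 class gives 7 * 5 * 5 * (4 + 1 + 1) = 1050 floats per image. The blob itself is four `unsigned` fields plus one `uint64_t`.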

Hi eh-steve,

Thanks for sharing your contribution!

Hello eh-steve,

First, thanks for the patch.

I managed to run it fine with yoloV2, including with a large batch size, but not with yoloV3. I keep getting errors from the TensorRT engine even before any results are parsed (just while building the device buffers and submitting them to the engine).

Q1: Could some other fix have been left out of your answer?
Q2: Are you using a batch size > 1 with yoloV3?

Thanks again for your help, really appreciate this !

In my tests, this patch decreases detection accuracy. Is that normal?

Hi dannykario,

I have successfully run yoloV3_tiny with a batch size of 30. Whether the engine file builds at that batch size depends on your available GPU memory, though, so I don’t think this is related to the patch.

Hi marcoslicianops,

I’ve noticed a general reduction in confidence values when benchmarking the TensorRT implementation of Yolo against AlexeyAB’s darknet implementation (on a range of different models). Lowering the confidence threshold seemed to bring recall back in line with darknet. Also, the anchor box indices to use depend on whether you’re using AlexeyAB- or PJReddie-trained models. If the boxes look the wrong size, you may just need to use different anchors.

What differences are you seeing exactly before/after this patch?
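A yolo layer's mask determines which anchors it uses. In decodeYoloV3Tensor, box `b` reads the pair at index `mask[b]` of the flat anchor list. Below is a minimal host-side sketch of that lookup, using the yoloV3-tiny anchor list from nvdsparsebbox_Yolo.cpp. The helper `selectAnchors` is hypothetical. It can help when comparing a candidate mask against the `mask =` lines in the .cfg the model was actually trained with:

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Anchor list as it appears in NvDsInferParseCustomYoloV3Tiny (w, h pairs, in pixels).
const std::vector<float> kTinyAnchors = {10, 14, 23, 27, 37, 58, 81, 82, 135, 169, 344, 319};

// Mirrors the lookup in decodeYoloV3Tensor: box b of a yolo layer uses
// pw = anchors[mask[b] * 2] and ph = anchors[mask[b] * 2 + 1].
std::vector<std::pair<float, float>> selectAnchors(const std::vector<float>& anchors,
                                                   const std::vector<int>& mask)
{
    std::vector<std::pair<float, float>> selected;
    for (int m : mask)
    {
        selected.emplace_back(anchors[m * 2], anchors[m * 2 + 1]);
    }
    return selected;
}
```

With mask `{1, 2, 3}` the second yolo layer gets (23, 27), (37, 58) and (81, 82). With `{0, 1, 2}` it gets (10, 14), (23, 27) and (37, 58). A mismatch shifts every box in that layer to a different anchor size.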

For anyone looking for a Deepstream 5.0 patch (though I was kinda hoping Nvidia would have merged it upstream by now):

diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/kernels.cu b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/kernels.cu
index 45032f0..43f1906 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/kernels.cu
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/kernels.cu
@@ -17,20 +17,20 @@
 
 inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
 
-__global__ void gpuYoloLayerV3(const float* input, float* output, const uint gridSize, const uint numOutputClasses,
+__global__ void gpuYoloLayerV3(const float* input, float* output, const uint gridSizeX,  const uint gridSizeY, const uint numOutputClasses,
                                const uint numBBoxes)
 {
     uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
     uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
     uint z_id = blockIdx.z * blockDim.z + threadIdx.z;
 
-    if ((x_id >= gridSize) || (y_id >= gridSize) || (z_id >= numBBoxes))
+    if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes))
     {
         return;
     }
 
-    const int numGridCells = gridSize * gridSize;
-    const int bbindex = y_id * gridSize + x_id;
+    const int numGridCells = gridSizeX * gridSizeY;
+    const int bbindex = y_id * gridSizeX + x_id;
 
     output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
         = sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]);
@@ -54,23 +54,23 @@ __global__ void gpuYoloLayerV3(const float* input, float* output, const uint gri
     }
 }
 
-cudaError_t cudaYoloLayerV3(const void* input, void* output, const uint& batchSize, const uint& gridSize,
+cudaError_t cudaYoloLayerV3(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
                             const uint& numOutputClasses, const uint& numBBoxes,
                             uint64_t outputSize, cudaStream_t stream);
 
-cudaError_t cudaYoloLayerV3(const void* input, void* output, const uint& batchSize, const uint& gridSize,
+cudaError_t cudaYoloLayerV3(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
                             const uint& numOutputClasses, const uint& numBBoxes,
                             uint64_t outputSize, cudaStream_t stream)
 {
     dim3 threads_per_block(16, 16, 4);
-    dim3 number_of_blocks((gridSize / threads_per_block.x) + 1,
-                          (gridSize / threads_per_block.y) + 1,
+    dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
+                          (gridSizeY / threads_per_block.y) + 1,
                           (numBBoxes / threads_per_block.z) + 1);
     for (unsigned int batch = 0; batch < batchSize; ++batch)
     {
         gpuYoloLayerV3<<<number_of_blocks, threads_per_block, 0, stream>>>(
             reinterpret_cast<const float*>(input) + (batch * outputSize),
-            reinterpret_cast<float*>(output) + (batch * outputSize), gridSize, numOutputClasses,
+            reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
             numBBoxes);
     }
     return cudaGetLastError();
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp
index 4226027..45399ef 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/nvdsparsebbox_Yolo.cpp
@@ -29,7 +29,7 @@
 #include <iostream>
 #include <unordered_map>
 
-static const int NUM_CLASSES_YOLO = 80;
+static const int NUM_CLASSES_YOLO = 13;
 
 extern "C" bool NvDsInferParseCustomYoloV3(
     std::vector<NvDsInferLayerInfo> const& outputLayersInfo,
@@ -167,22 +167,22 @@ nmsAllClasses(const float nmsThresh,
 static std::vector<NvDsInferParseObjectInfo>
 decodeYoloV2Tensor(
     const float* detections, const std::vector<float> &anchors,
-    const uint gridSize, const uint stride, const uint numBBoxes,
+    const uint gridSizeX, const uint gridSizeY, const uint stride, const uint numBBoxes,
     const uint numOutputClasses, const float probThresh, const uint& netW,
     const uint& netH)
 {
     std::vector<NvDsInferParseObjectInfo> binfo;
-    for (uint y = 0; y < gridSize; ++y)
+    for (uint y = 0; y < gridSizeY; ++y)
     {
-        for (uint x = 0; x < gridSize; ++x)
+        for (uint x = 0; x < gridSizeX; ++x)
         {
             for (uint b = 0; b < numBBoxes; ++b)
             {
                 const float pw = anchors[b * 2];
                 const float ph = anchors[b * 2 + 1];
 
-                const int numGridCells = gridSize * gridSize;
-                const int bbindex = y * gridSize + x;
+                const int numGridCells = gridSizeX * gridSizeY;
+                const int bbindex = y * gridSizeX + x;
                 const float bx
                     = x + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 0)];
                 const float by
@@ -226,22 +226,22 @@ decodeYoloV2Tensor(
 static std::vector<NvDsInferParseObjectInfo>
 decodeYoloV3Tensor(
     const float* detections, const std::vector<int> &mask, const std::vector<float> &anchors,
-    const uint gridSize, const uint stride, const uint numBBoxes,
+    const uint gridSizeX, const uint gridSizeY, const uint stride, const uint numBBoxes,
     const uint numOutputClasses, const float probThresh, const uint& netW,
     const uint& netH)
 {
     std::vector<NvDsInferParseObjectInfo> binfo;
-    for (uint y = 0; y < gridSize; ++y)
+    for (uint y = 0; y < gridSizeY; ++y)
     {
-        for (uint x = 0; x < gridSize; ++x)
+        for (uint x = 0; x < gridSizeX; ++x)
         {
             for (uint b = 0; b < numBBoxes; ++b)
             {
                 const float pw = anchors[mask[b] * 2];
                 const float ph = anchors[mask[b] * 2 + 1];
 
-                const int numGridCells = gridSize * gridSize;
-                const int bbindex = y * gridSize + x;
+                const int numGridCells = gridSizeX * gridSizeY;
+                const int bbindex = y * gridSizeX + x;
                 const float bx
                     = x + detections[bbindex + numGridCells * (b * (5 + numOutputClasses) + 0)];
                 const float by
@@ -304,8 +304,8 @@ static bool NvDsInferParseYoloV3(
     const std::vector<std::vector<int>> &masks)
 {
     const uint kNUM_BBOXES = 3;
-    static const float kNMS_THRESH = 0.3f;
-    static const float kPROB_THRESH = 0.7f;
+    static const float kNMS_THRESH = 0.4f;
+    static const float kPROB_THRESH = 0.3f;
 
     const std::vector<const NvDsInferLayerInfo*> sortedLayers =
         SortLayers (outputLayersInfo);
@@ -328,11 +328,12 @@ static bool NvDsInferParseYoloV3(
     for (uint idx = 0; idx < masks.size(); ++idx) {
         const NvDsInferLayerInfo &layer = *sortedLayers[idx]; // 255 x Grid x Grid
         assert (layer.dims.numDims == 3);
-        const uint gridSize = layer.dims.d[1];
-        const uint stride = networkInfo.width / gridSize;
+      const uint gridSizeY = layer.dims.d[1];
+      const uint gridSizeX = layer.dims.d[2];
+      const uint stride = networkInfo.height / gridSizeY;
 
         std::vector<NvDsInferParseObjectInfo> outObjs =
-            decodeYoloV3Tensor((const float*)(layer.buffer), masks[idx], anchors, gridSize, stride, kNUM_BBOXES,
+            decodeYoloV3Tensor((const float*)(layer.buffer), masks[idx], anchors, gridSizeX, gridSizeY, stride, kNUM_BBOXES,
                        NUM_CLASSES_YOLO, kPROB_THRESH, networkInfo.width, networkInfo.height);
         objects.insert(objects.end(), outObjs.begin(), outObjs.end());
     }
@@ -373,8 +374,8 @@ extern "C" bool NvDsInferParseCustomYoloV3Tiny(
         10, 14, 23, 27, 37, 58, 81, 82, 135, 169, 344, 319};
     static const std::vector<std::vector<int>> kMASKS = {
         {3, 4, 5},
-        //{0, 1, 2}}; // as per output result, select {1,2,3}
-        {1, 2, 3}};
+        {0, 1, 2}}; // as per output result, select {1,2,3}
+//        {1, 2, 3}};
 
     return NvDsInferParseYoloV3 (
         outputLayersInfo, networkInfo, detectionParams, objectList,
@@ -408,10 +409,12 @@ static bool NvDsInferParseYoloV2(
     }
 
     assert (layer.dims.numDims == 3);
-    const uint gridSize = layer.dims.d[1];
-    const uint stride = networkInfo.width / gridSize;
+    const uint gridSizeY = layer.dims.d[1];
+    const uint gridSizeX = layer.dims.d[2];
+    const uint stride = networkInfo.height / gridSizeY;
+
     std::vector<NvDsInferParseObjectInfo> objects =
-        decodeYoloV2Tensor((const float*)(layer.buffer), kANCHORS, gridSize, stride, kNUM_BBOXES,
+        decodeYoloV2Tensor((const float*)(layer.buffer), kANCHORS, gridSizeX, gridSizeY, stride, kNUM_BBOXES,
                    NUM_CLASSES_YOLO, probthreshold, networkInfo.width, networkInfo.height);
 
     objectList.clear();
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.cpp b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.cpp
index 17049e8..46ee475 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.cpp
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.cpp
@@ -373,19 +373,20 @@ nvinfer1::ILayer* netAddUpsample(int layerIdx, std::map<std::string, std::string
     assert(block.at("type") == "upsample");
     nvinfer1::Dims inpDims = input->getDimensions();
     assert(inpDims.nbDims == 3);
-    assert(inpDims.d[1] == inpDims.d[2]);
+//    assert(inpDims.d[1] == inpDims.d[2]);
     int h = inpDims.d[1];
     int w = inpDims.d[2];
     int stride = std::stoi(block.at("stride"));
     // add pre multiply matrix as a constant
     nvinfer1::Dims preDims{3,
-                           {1, stride * h, w},
-                           {nvinfer1::DimensionType::kCHANNEL, nvinfer1::DimensionType::kSPATIAL,
+                           {1, stride * h, h},
+                           {nvinfer1::DimensionType::kCHANNEL,
+                            nvinfer1::DimensionType::kSPATIAL,
                             nvinfer1::DimensionType::kSPATIAL}};
-    int size = stride * h * w;
+    int size = stride * h * h;
     nvinfer1::Weights preMul{nvinfer1::DataType::kFLOAT, nullptr, size};
     float* preWt = new float[size];
-    /* (2*h * w)
+    /* (2*h * h)
     [ [1, 0, ..., 0],
       [1, 0, ..., 0],
       [0, 1, ..., 0],
@@ -397,12 +398,9 @@ nvinfer1::ILayer* netAddUpsample(int layerIdx, std::map<std::string, std::string
     */
     for (int i = 0, idx = 0; i < h; ++i)
     {
-        for (int s = 0; s < stride; ++s)
+        for (int j = 0; j < h * stride; ++j, ++idx)
         {
-            for (int j = 0; j < w; ++j, ++idx)
-            {
-                preWt[idx] = (i == j) ? 1.0 : 0.0;
-            }
+            preWt[idx] = (j % h == i) ? 1.0 : 0.0; // row i*stride + j/h, col j%h: output row copies input row i
         }
     }
     preMul.values = preWt;
@@ -413,20 +411,20 @@ nvinfer1::ILayer* netAddUpsample(int layerIdx, std::map<std::string, std::string
     preM->setName(preLayerName.c_str());
     // add post multiply matrix as a constant
     nvinfer1::Dims postDims{3,
-                            {1, h, stride * w},
+                            {1, w, stride * w},
                             {nvinfer1::DimensionType::kCHANNEL, nvinfer1::DimensionType::kSPATIAL,
                              nvinfer1::DimensionType::kSPATIAL}};
-    size = stride * h * w;
+    size = stride * w * w;
     nvinfer1::Weights postMul{nvinfer1::DataType::kFLOAT, nullptr, size};
     float* postWt = new float[size];
-    /* (h * 2*w)
+    /* (w * 2*w)
     [ [1, 1, 0, 0, ..., 0, 0],
       [0, 0, 1, 1, ..., 0, 0],
       ...,
       ...,
       [0, 0, 0, 0, ..., 1, 1] ]
     */
-    for (int i = 0, idx = 0; i < h; ++i)
+    for (int i = 0, idx = 0; i < w; ++i)
     {
         for (int j = 0; j < stride * w; ++j, ++idx)
         {
@@ -441,8 +439,8 @@ nvinfer1::ILayer* netAddUpsample(int layerIdx, std::map<std::string, std::string
     post_m->setName(postLayerName.c_str());
     // add matrix multiply layers for upsampling
     nvinfer1::IMatrixMultiplyLayer* mm1
-        = network->addMatrixMultiply(*preM->getOutput(0), nvinfer1::MatrixOperation::kNONE, *input,
-                                     nvinfer1::MatrixOperation::kNONE);
+        = network->addMatrixMultiply(*preM->getOutput(0), nvinfer1::MatrixOperation::kNONE,
+                                     *input,nvinfer1::MatrixOperation::kNONE);
     assert(mm1 != nullptr);
     std::string mm1LayerName = "mm1_" + std::to_string(layerIdx);
     mm1->setName(mm1LayerName.c_str());
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.h b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.h
index 97dcc5f..26e901b 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.h
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/trt_utils.h
@@ -46,23 +46,26 @@ private:
                              nvinfer1::DimsHW stride, nvinfer1::DimsHW padding,
                              nvinfer1::DimsHW dilation, const char* layerName) const override
     {
-        assert(inputDims.d[0] == inputDims.d[1]);
+//        assert(inputDims.d[0] == inputDims.d[1]);
         assert(kernelSize.d[0] == kernelSize.d[1]);
         assert(stride.d[0] == stride.d[1]);
         assert(padding.d[0] == padding.d[1]);
 
-        int outputDim;
+        int outputDimH;
+        int outputDimW;
         // Only layer maxpool_12 makes use of same padding
         if (m_SamePaddingLayers.find(layerName) != m_SamePaddingLayers.end())
         {
-            outputDim = (inputDims.d[0] + 2 * padding.d[0]) / stride.d[0];
+            outputDimH = (inputDims.d[0] + 2 * padding.d[0]) / stride.d[0];
+            outputDimW = (inputDims.d[1] + 2 * padding.d[1]) / stride.d[1];
         }
         // Valid Padding
         else
         {
-            outputDim = (inputDims.d[0] - kernelSize.d[0]) / stride.d[0] + 1;
+            outputDimH = (inputDims.d[0] - kernelSize.d[0]) / stride.d[0] + 1;
+            outputDimW = (inputDims.d[1] - kernelSize.d[1]) / stride.d[1] + 1;
         }
-        return nvinfer1::DimsHW{outputDim, outputDim};
+        return nvinfer1::DimsHW{outputDimH, outputDimW};
     }
 
 public:
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.cpp b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.cpp
index 379694a..bbc34eb 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.cpp
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.cpp
@@ -173,19 +173,21 @@ nvinfer1::INetworkDefinition *Yolo::createYoloNetwork (
         else if (m_configBlocks.at(i).at("type") == "yolo")
         {
             nvinfer1::Dims prevTensorDims = previous->getDimensions();
-            assert(prevTensorDims.d[1] == prevTensorDims.d[2]);
+//            assert(prevTensorDims.d[1] == prevTensorDims.d[2]);
             TensorInfo& curYoloTensor = m_OutputTensors.at(outputTensorCount);
-            curYoloTensor.gridSize = prevTensorDims.d[1];
-            curYoloTensor.stride = m_InputW / curYoloTensor.gridSize;
-            m_OutputTensors.at(outputTensorCount).volume = curYoloTensor.gridSize
-                * curYoloTensor.gridSize
+            curYoloTensor.gridSizeY = prevTensorDims.d[1];
+            curYoloTensor.gridSizeX = prevTensorDims.d[2];
+            curYoloTensor.stride = m_InputH / curYoloTensor.gridSizeY;
+            m_OutputTensors.at(outputTensorCount).volume = curYoloTensor.gridSizeY
+                * curYoloTensor.gridSizeX
                 * (curYoloTensor.numBBoxes * (5 + curYoloTensor.numClasses));
             std::string layerName = "yolo_" + std::to_string(i);
             curYoloTensor.blobName = layerName;
             nvinfer1::IPluginV2* yoloPlugin
                 = new YoloLayerV3(m_OutputTensors.at(outputTensorCount).numBBoxes,
                                   m_OutputTensors.at(outputTensorCount).numClasses,
-                                  m_OutputTensors.at(outputTensorCount).gridSize);
+                                  m_OutputTensors.at(outputTensorCount).gridSizeX,
+                                  m_OutputTensors.at(outputTensorCount).gridSizeY);
             assert(yoloPlugin != nullptr);
             nvinfer1::IPluginV2Layer* yolo = network->addPluginV2(&previous, 1, *yoloPlugin);
             assert(yolo != nullptr);
@@ -206,10 +208,11 @@ nvinfer1::INetworkDefinition *Yolo::createYoloNetwork (
             nvinfer1::Dims prevTensorDims = previous->getDimensions();
             assert(prevTensorDims.d[1] == prevTensorDims.d[2]);
             TensorInfo& curRegionTensor = m_OutputTensors.at(outputTensorCount);
-            curRegionTensor.gridSize = prevTensorDims.d[1];
-            curRegionTensor.stride = m_InputW / curRegionTensor.gridSize;
-            m_OutputTensors.at(outputTensorCount).volume = curRegionTensor.gridSize
-                * curRegionTensor.gridSize
+            curRegionTensor.gridSizeY = prevTensorDims.d[1];
+            curRegionTensor.gridSizeX = prevTensorDims.d[2];
+            curRegionTensor.stride = m_InputW / curRegionTensor.gridSizeX;
+            m_OutputTensors.at(outputTensorCount).volume = curRegionTensor.gridSizeX
+                * curRegionTensor.gridSizeY
                 * (curRegionTensor.numBBoxes * (5 + curRegionTensor.numClasses));
             std::string layerName = "region_" + std::to_string(i);
             curRegionTensor.blobName = layerName;
@@ -423,7 +426,7 @@ void Yolo::parseConfigBlocks()
             m_InputH = std::stoul(block.at("height"));
             m_InputW = std::stoul(block.at("width"));
             m_InputC = std::stoul(block.at("channels"));
-            assert(m_InputW == m_InputH);
+//            assert(m_InputW == m_InputH);
             m_InputSize = m_InputC * m_InputH * m_InputW;
         }
         else if ((block.at("type") == "region") || (block.at("type") == "yolo"))
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.h b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.h
index 968ba2b..f002588 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.h
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yolo.h
@@ -50,7 +50,8 @@ struct TensorInfo
 {
     std::string blobName;
     uint stride{0};
-    uint gridSize{0};
+    uint gridSizeY{0};
+    uint gridSizeX{0};
     uint numClasses{0};
     uint numBBoxes{0};
     uint64_t volume{0};
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp
index e8a90b3..1da010c 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.cpp
@@ -45,7 +45,7 @@ void read(const char*& buffer, T& val)
 // Forward declaration of cuda kernels
 cudaError_t cudaYoloLayerV3 (
     const void* input, void* output, const uint& batchSize,
-    const uint& gridSize, const uint& numOutputClasses,
+    const uint& gridSizeX, const uint& gridSizeY,  const uint& numOutputClasses,
     const uint& numBBoxes, uint64_t outputSize, cudaStream_t stream);
 
 YoloLayerV3::YoloLayerV3 (const void* data, size_t length)
@@ -53,20 +53,23 @@ YoloLayerV3::YoloLayerV3 (const void* data, size_t length)
     const char *d = static_cast<const char*>(data);
     read(d, m_NumBoxes);
     read(d, m_NumClasses);
-    read(d, m_GridSize);
+    read(d, m_GridSizeX);
+    read(d, m_GridSizeY);
     read(d, m_OutputSize);
 };
 
 YoloLayerV3::YoloLayerV3 (
-    const uint& numBoxes, const uint& numClasses, const uint& gridSize) :
+    const uint& numBoxes, const uint& numClasses, const uint& gridSizeX, const uint& gridSizeY) :
     m_NumBoxes(numBoxes),
     m_NumClasses(numClasses),
-    m_GridSize(gridSize)
+    m_GridSizeX(gridSizeX),
+    m_GridSizeY(gridSizeY)
 {
     assert(m_NumBoxes > 0);
     assert(m_NumClasses > 0);
-    assert(m_GridSize > 0);
-    m_OutputSize = m_GridSize * m_GridSize * (m_NumBoxes * (4 + 1 + m_NumClasses));
+    assert(m_GridSizeX > 0);
+    assert(m_GridSizeY > 0);
+    m_OutputSize = m_GridSizeX * m_GridSizeY * (m_NumBoxes * (4 + 1 + m_NumClasses));
 };
 
 nvinfer1::Dims
@@ -100,14 +103,14 @@ int YoloLayerV3::enqueue(
     cudaStream_t stream)
 {
     CHECK(cudaYoloLayerV3(
-              inputs[0], outputs[0], batchSize, m_GridSize, m_NumClasses, m_NumBoxes,
+              inputs[0], outputs[0], batchSize, m_GridSizeX, m_GridSizeY, m_NumClasses, m_NumBoxes,
               m_OutputSize, stream));
     return 0;
 }
 
 size_t YoloLayerV3::getSerializationSize() const
 {
-    return sizeof(m_NumBoxes) + sizeof(m_NumClasses) + sizeof(m_GridSize) + sizeof(m_OutputSize);
+    return sizeof(m_NumBoxes) + sizeof(m_NumClasses) + sizeof(m_GridSizeX) + sizeof(m_GridSizeY) + sizeof(m_OutputSize);
 }
 
 void YoloLayerV3::serialize(void* buffer) const
@@ -115,13 +118,14 @@ void YoloLayerV3::serialize(void* buffer) const
     char *d = static_cast<char*>(buffer);
     write(d, m_NumBoxes);
     write(d, m_NumClasses);
-    write(d, m_GridSize);
+    write(d, m_GridSizeX);
+    write(d, m_GridSizeY);
     write(d, m_OutputSize);
 }
 
 nvinfer1::IPluginV2* YoloLayerV3::clone() const
 {
-    return new YoloLayerV3 (m_NumBoxes, m_NumClasses, m_GridSize);
+    return new YoloLayerV3 (m_NumBoxes, m_NumClasses, m_GridSizeX, m_GridSizeY);
 }
 
 REGISTER_TENSORRT_PLUGIN(YoloLayerV3PluginCreator);
diff --git a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.h b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.h
index f10047e..21487b2 100755
--- a/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.h
+++ b/objectDetector_Yolo/nvdsinfer_custom_impl_Yolo/yoloPlugins.h
@@ -51,7 +51,7 @@ class YoloLayerV3 : public nvinfer1::IPluginV2
 {
 public:
     YoloLayerV3 (const void* data, size_t length);
-    YoloLayerV3 (const uint& numBoxes, const uint& numClasses, const uint& gridSize);
+    YoloLayerV3 (const uint& numBoxes, const uint& numClasses, const uint& gridSizeX, const uint& gridSizeY);
     const char* getPluginType () const override { return YOLOV3LAYER_PLUGIN_NAME; }
     const char* getPluginVersion () const override { return YOLOV3LAYER_PLUGIN_VERSION; }
     int getNbOutputs () const override { return 1; }
@@ -89,7 +89,8 @@ public:
 private:
     uint m_NumBoxes {0};
     uint m_NumClasses {0};
-    uint m_GridSize {0};
+    uint m_GridSizeX {0};
+    uint m_GridSizeY {0};
     uint64_t m_OutputSize {0};
     std::string m_Namespace {""};
 };
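The core of the non-square change is that a yolo output is laid out as CHW with H = gridSizeY and W = gridSizeX. The row stride inside each channel plane is therefore gridSizeX, and the kernel launch needs enough blocks in x and y separately. Below is a minimal host-side sketch of the same offset and block-count arithmetic used by the patched gpuYoloLayerV3, cudaYoloLayerV3 and decodeYoloV3Tensor. The function names are hypothetical. It is handy for checking a hand-computed index on, say, a 7x5 grid:

```cpp
#include <cassert>
#include <cstddef>

// Flat offset of attribute `attr` (0-3 box coords, 4 objectness, 5+ class scores)
// for box `b` in grid cell (x, y), matching the patched indexing:
// CHW layout with H = gridSizeY rows and W = gridSizeX columns.
std::size_t yoloOffset(unsigned x, unsigned y, unsigned b, unsigned attr,
                       unsigned gridSizeX, unsigned gridSizeY, unsigned numClasses)
{
    const std::size_t numGridCells = static_cast<std::size_t>(gridSizeX) * gridSizeY;
    // Rows advance by gridSizeX (the width), not gridSizeY.
    const std::size_t bbindex = static_cast<std::size_t>(y) * gridSizeX + x;
    const std::size_t channel = static_cast<std::size_t>(b) * (5 + numClasses) + attr;
    return bbindex + numGridCells * channel;
}

// Blocks per launch dimension as in cudaYoloLayerV3: (n / threads) + 1.
// This always covers n; the surplus threads return early on the bounds check.
unsigned blocksFor(unsigned n, unsigned threadsPerBlock)
{
    return n / threadsPerBlock + 1;
}
```

On a grid with gridSizeX = 7, gridSizeY = 5, 5 boxes and 1 class, the last value (x = 6, y = 4, b = 4, attr = 5) lands at offset 1049, one less than the 1050-float output size.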

My objects were no longer recognized with a non-square pre-trained model. When I switch to a square-trained model, it works normally without this patch. If I’m not mistaken, even square models don’t work well with this patch.

Hi marcoslicianops,

Having tested with the original 416x416 PJReddie model, I can measure a drop in performance after this patch. Switching back to using the original anchors for this model results in some improvement, but not the same results as before the patch. I’ll dig into this when I get a chance, thanks for pointing it out!
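When chasing an accuracy drop like this, one self-contained thing to check is the upsample layer. netAddUpsample expresses it as pre * X * post, using two matrix multiplies with constant 0/1 matrices (preM and post_m in the trt_utils.cpp hunk). For a non-square h x w plane, the pre-multiply matrix must be (stride*h) x h, with output row r taking input row r / stride. The post-multiply matrix must be w x (stride*w), with output column c taking input column c / stride. Below is a small CPU sketch that builds both matrices and checks the result against nearest-neighbour upsampling. It uses hypothetical helper names and plain `std::vector`, with no TensorRT. Flattening its matrices row-major should give the contents expected in the preWt and postWt buffers:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Pre-multiply matrix, (stride*h) x h: output row r copies input row r / stride.
Matrix makePre(int h, int stride)
{
    Matrix pre(stride * h, std::vector<float>(h, 0.0f));
    for (int r = 0; r < stride * h; ++r)
    {
        pre[r][r / stride] = 1.0f;
    }
    return pre;
}

// Post-multiply matrix, w x (stride*w): output column c copies input column c / stride.
Matrix makePost(int w, int stride)
{
    Matrix post(w, std::vector<float>(stride * w, 0.0f));
    for (int c = 0; c < stride * w; ++c)
    {
        post[c / stride][c] = 1.0f;
    }
    return post;
}

// Naive dense matrix product a (n x k) * b (k x m).
Matrix matmul(const Matrix& a, const Matrix& b)
{
    const std::size_t n = a.size(), k = b.size(), m = b[0].size();
    Matrix out(n, std::vector<float>(m, 0.0f));
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t p = 0; p < k; ++p)
            for (std::size_t j = 0; j < m; ++j)
                out[i][j] += a[i][p] * b[p][j];
    return out;
}

// Upsample one h x w channel plane as pre * X * post.
Matrix upsample(const Matrix& x, int stride)
{
    const int h = static_cast<int>(x.size());
    const int w = static_cast<int>(x[0].size());
    return matmul(matmul(makePre(h, stride), x), makePost(w, stride));
}
```

For a 2x3 input and stride 2, every output element y[r][c] should equal x[r / 2][c / 2]. If a row or column of the result is all zeros, the corresponding 0/1 matrix is missing entries.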
