ONNX model with Jetson-Inference using GPU

Hi,

I retrained a model (SSD-Mobilenet-v1) using jetson-inference and PyTorch (https://github.com/dusty-nv/jetson-inference/blob/master/docs/pytorch-ssd.md), and then I generated an ONNX file (for person detection).

But when trying to run this model with jetson.inference.detectNet in Python (I made some changes in the source code to use the GPU with FP16, which works well with the original ssd_mobilenet_v2_coco.uff), TensorRT refuses to run inference with the ONNX model (I also tried INT8 and FP32 without success):

[TRT] device GPU, completed writing engine cache to /usr/local/bin/networks/SSD-Mobilenet-v1-ONNX/ssd-mobilenet.onnx.1.0.7100.GPU.FP16.engine
[TRT] device GPU, loaded /usr/local/bin/networks/SSD-Mobilenet-v1-ONNX/ssd-mobilenet.onnx
[TRT] Deserialize required 123757 microseconds.
[TRT]
[TRT] CUDA engine context initialized on device GPU:
[TRT] -- layers 97
[TRT] -- maxBatchSize 1
[TRT] -- workspace 0
[TRT] -- deviceMemory 20092416
[TRT] -- bindings 3
[TRT] binding 0
-- index 0
-- name 'input_0'
-- type FP32
-- in/out INPUT
-- # dims 4
-- dim #0 1 (SPATIAL)
-- dim #1 3 (SPATIAL)
-- dim #2 300 (SPATIAL)
-- dim #3 300 (SPATIAL)
[TRT] binding 1
-- index 1
-- name 'scores'
-- type FP32
-- in/out OUTPUT
-- # dims 3
-- dim #0 1 (SPATIAL)
-- dim #1 3000 (SPATIAL)
-- dim #2 2 (SPATIAL)
[TRT] binding 2
-- index 2
-- name 'boxes'
-- type FP32
-- in/out OUTPUT
-- # dims 3
-- dim #0 1 (SPATIAL)
-- dim #1 3000 (SPATIAL)
-- dim #2 4 (SPATIAL)
[TRT]
[TRT] INVALID_ARGUMENT: Cannot find binding of given name: Input
[TRT] failed to find requested input layer Input in network
[TRT] device GPU, failed to create resources for CUDA engine
[TRT] failed to create TensorRT engine for /usr/local/bin/networks/SSD-Mobilenet-v1-ONNX/ssd-mobilenet.onnx, device GPU
[TRT] detectNet -- failed to initialize.

Any idea what is wrong? The model runs successfully with detectnet when using the C++ version, but it uses the CPU instead of the GPU:

detectnet --model=models/Person/ssd-mobilenet.onnx --labels=models/Person/labels.txt
--input-blob=input_0 --output-cvg=scores --output-bbox=boxes
"images/*.jpg" test_Person

Files available at:

Hmm, it doesn't seem to be receiving/parsing your custom command-line arguments.

When you run that command line in the terminal, are there line breaks? Can you try running it all on one line?
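For example, all on one line:

detectnet --model=models/Person/ssd-mobilenet.onnx --labels=models/Person/labels.txt --input-blob=input_0 --output-cvg=scores --output-bbox=boxes "images/*.jpg" test_Person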

I didn't run the Python script; I implemented the function in another script with no command-line args (the values to pass are set in the script).

I also made some changes in the library, I will upload the files soon.

Now I call the functions with:
labels = open("jetson-inference/data/networks/SSD-Mobilenet-v1-ONNX/labels.txt").readlines()
net = jetson.inference.detectNet("ssd-mobilenet-v1-onnx", threshold=0.7, precision="FP16", device="GPU", allowGPUFallback=True)

These are the changes I made in the library.

Changes in PyDetectNet.cpp:

// Init
static int PyDetectNet_Init( PyDetectNet_Object* self, PyObject *args, PyObject *kwds )
{
    LogDebug(LOG_PY_INFERENCE "PyDetectNet_Init()\n");

    // parse arguments
    PyObject* argList = NULL;
    const char* network = "ssd-mobilenet-v2";
    float threshold = DETECTNET_DEFAULT_THRESHOLD;

    const char* precision = "FP16";
    // precisionType PrecisionType = TYPE_FP32;

    const char* device = "GPU";
    // deviceType DeviceType = DEVICE_GPU;

    int allowGPUFallback = false;

    static char* kwlist[] = {"network", "threshold", "precision", "device", "allowGPUFallback", NULL};

    if( !PyArg_ParseTupleAndKeywords(args, kwds, "|sfssp", kwlist, &network, &threshold, &precision, &device, &allowGPUFallback) )
    {
        PyErr_SetString(PyExc_Exception, LOG_PY_INFERENCE "detectNet.__init__() failed to parse args tuple");
        printf("%s\n", network);
        printf("%f\n", threshold);
        printf("%s\n", precision);
        printf("%s\n", device);
        // printf("%d\n", allowGPUFallback);
        return -1;
    }

    LogVerbose(LOG_PY_INFERENCE "detectNet loading built-in network '%s'\n", network);

    // parse the selected built-in network
    detectNet::NetworkType networkType = detectNet::NetworkTypeFromStr(network);

    uint32_t maxBatchSize = DEFAULT_MAX_BATCH_SIZE;
    precisionType precision_type = precisionTypeFromStr(precision);
    deviceType device_type = deviceTypeFromStr(device);

    if( networkType == detectNet::CUSTOM )
    {
        PyErr_SetString(PyExc_Exception, LOG_PY_INFERENCE "detectNet invalid built-in network was requested");
        printf(LOG_PY_INFERENCE "detectNet invalid built-in network was requested ('%s')\n", network);
        return -1;
    }

    // load the built-in network with the requested precision/device
    self->net = detectNet::Create(networkType, threshold, maxBatchSize, precision_type, device_type, allowGPUFallback);

    // confirm the network loaded
    if( !self->net )
    {
        PyErr_SetString(PyExc_Exception, LOG_PY_INFERENCE "detectNet failed to load network");
        LogError(LOG_PY_INFERENCE "detectNet failed to load network\n");
        return -1;
    }

    self->base.net = self->net;
    return 0;
}

Changes in detectNet.cpp:

detectNet* detectNet::Create( NetworkType networkType, float threshold, uint32_t maxBatchSize,
                              precisionType precision, deviceType device, bool allowGPUFallback )
{
#if 1
    if( networkType == PEDNET_MULTI )
        return Create("networks/multiped-500/deploy.prototxt", "networks/multiped-500/snapshot_iter_178000.caffemodel", 117.0f, "networks/multiped-500/class_labels.txt", threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
    else if( networkType == FACENET )
        return Create("networks/facenet-120/deploy.prototxt", "networks/facenet-120/snapshot_iter_24000.caffemodel", 0.0f, "networks/facenet-120/class_labels.txt", threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
    else if( networkType == PEDNET )
        return Create("networks/ped-100/deploy.prototxt", "networks/ped-100/snapshot_iter_70800.caffemodel", 0.0f, "networks/ped-100/class_labels.txt", threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
    else if( networkType == COCO_AIRPLANE )
        return Create("networks/DetectNet-COCO-Airplane/deploy.prototxt", "networks/DetectNet-COCO-Airplane/snapshot_iter_22500.caffemodel", 0.0f, "networks/DetectNet-COCO-Airplane/class_labels.txt", threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
    else if( networkType == COCO_BOTTLE )
        return Create("networks/DetectNet-COCO-Bottle/deploy.prototxt", "networks/DetectNet-COCO-Bottle/snapshot_iter_59700.caffemodel", 0.0f, "networks/DetectNet-COCO-Bottle/class_labels.txt", threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
    else if( networkType == COCO_CHAIR )
        return Create("networks/DetectNet-COCO-Chair/deploy.prototxt", "networks/DetectNet-COCO-Chair/snapshot_iter_89500.caffemodel", 0.0f, "networks/DetectNet-COCO-Chair/class_labels.txt", threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
    else if( networkType == COCO_DOG )
        return Create("networks/DetectNet-COCO-Dog/deploy.prototxt", "networks/DetectNet-COCO-Dog/snapshot_iter_38600.caffemodel", 0.0f, "networks/DetectNet-COCO-Dog/class_labels.txt", threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
#if NV_TENSORRT_MAJOR > 4
    else if( networkType == SSD_INCEPTION_V2 )
        return Create("networks/SSD-Inception-v2/ssd_inception_v2_coco.uff", "networks/SSD-Inception-v2/ssd_coco_labels.txt", threshold, "Input", Dims3(3,300,300), "NMS", "NMS_1", maxBatchSize, precision, device, allowGPUFallback);
    else if( networkType == SSD_MOBILENET_V1_ONNX )
        return Create("networks/SSD-Mobilenet-v1-ONNX/ssd-mobilenet.onnx", "networks/SSD-Mobilenet-v1-ONNX/labels.txt", threshold, "Input", Dims3(3,300,300), "NMS", "NMS_1", maxBatchSize, precision, device, allowGPUFallback);
    else if( networkType == SSD_MOBILENET_V1 )
        return Create("networks/SSD-Mobilenet-v1/ssd_mobilenet_v1_coco.uff", "networks/SSD-Mobilenet-v1/ssd_coco_labels.txt", threshold, "Input", Dims3(3,300,300), "Postprocessor", "Postprocessor_1", maxBatchSize, precision, device, allowGPUFallback);
    else if( networkType == SSD_MOBILENET_V2 )
        return Create("networks/SSD-Mobilenet-v2/ssd_mobilenet_v2_coco.uff", "networks/SSD-Mobilenet-v2/ssd_coco_labels.txt", threshold, "Input", Dims3(3,300,300), "NMS", "NMS_1", maxBatchSize, precision, device, allowGPUFallback);
#endif
    else
        return NULL;
#else
    if( networkType == PEDNET_MULTI )
        return Create("networks/multiped-500/deploy.prototxt", "networks/multiped-500/snapshot_iter_178000.caffemodel", "networks/multiped-500/mean.binaryproto", threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
    else if( networkType == FACENET )
        return Create("networks/facenet-120/deploy.prototxt", "networks/facenet-120/snapshot_iter_24000.caffemodel", NULL, threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
    else if( networkType == PEDNET )
        return Create("networks/ped-100/deploy.prototxt", "networks/ped-100/snapshot_iter_70800.caffemodel", "networks/ped-100/mean.binaryproto", threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
    else if( networkType == COCO_AIRPLANE )
        return Create("networks/DetectNet-COCO-Airplane/deploy.prototxt", "networks/DetectNet-COCO-Airplane/snapshot_iter_22500.caffemodel", "networks/DetectNet-COCO-Airplane/mean.binaryproto", threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
    else if( networkType == COCO_BOTTLE )
        return Create("networks/DetectNet-COCO-Bottle/deploy.prototxt", "networks/DetectNet-COCO-Bottle/snapshot_iter_59700.caffemodel", "networks/DetectNet-COCO-Bottle/mean.binaryproto", threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
    else if( networkType == COCO_CHAIR )
        return Create("networks/DetectNet-COCO-Chair/deploy.prototxt", "networks/DetectNet-COCO-Chair/snapshot_iter_89500.caffemodel", "networks/DetectNet-COCO-Chair/mean.binaryproto", threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
    else if( networkType == COCO_DOG )
        return Create("networks/DetectNet-COCO-Dog/deploy.prototxt", "networks/DetectNet-COCO-Dog/snapshot_iter_38600.caffemodel", "networks/DetectNet-COCO-Dog/mean.binaryproto", threshold, DETECTNET_DEFAULT_INPUT, DETECTNET_DEFAULT_COVERAGE, DETECTNET_DEFAULT_BBOX, maxBatchSize, precision, device, allowGPUFallback );
    else
        return NULL;
#endif
}

// Create
detectNet* detectNet::Create( const commandLine& cmdLine )
{
    detectNet* net = NULL;

    // parse command line parameters
    const char* modelName = cmdLine.GetString("network");

    if( !modelName )
        modelName = cmdLine.GetString("model", "ssd-mobilenet-v2");

    float threshold = cmdLine.GetFloat("threshold");

    if( threshold == 0.0f )
        threshold = DETECTNET_DEFAULT_THRESHOLD;

    int maxBatchSize = cmdLine.GetInt("batch_size");

    if( maxBatchSize < 1 )
        maxBatchSize = DEFAULT_MAX_BATCH_SIZE;

    const char* precisionName = cmdLine.GetString("precision");

    if( !precisionName )
        precisionName = cmdLine.GetString("precision", "FP16");

    // parse the precision type
    const precisionType type_precision = precisionTypeFromStr(precisionName);

    const char* deviceName = cmdLine.GetString("device");

    if( !deviceName )
        deviceName = cmdLine.GetString("device", "GPU");

    // parse the device type
    const deviceType type_device = deviceTypeFromStr(deviceName);

    bool allowGPUFallback_value = cmdLine.GetBool("allowGPUFallback");

    if( !allowGPUFallback_value )
        allowGPUFallback_value = cmdLine.GetBool("allowGPUFallback", false);

    // parse the model type
    const detectNet::NetworkType type = NetworkTypeFromStr(modelName);

    if( type == detectNet::CUSTOM )
    {
        const char* prototxt     = cmdLine.GetString("prototxt");
        const char* input        = cmdLine.GetString("input_blob");
        const char* out_blob     = cmdLine.GetString("output_blob");
        const char* out_cvg      = cmdLine.GetString("output_cvg");
        const char* out_bbox     = cmdLine.GetString("output_bbox");
        const char* class_labels = cmdLine.GetString("class_labels");

        if( !input )
            input = DETECTNET_DEFAULT_INPUT;

        if( !out_blob )
        {
            if( !out_cvg )  out_cvg  = DETECTNET_DEFAULT_COVERAGE;
            if( !out_bbox ) out_bbox = DETECTNET_DEFAULT_BBOX;
        }

        if( !class_labels )
            class_labels = cmdLine.GetString("labels");

        float meanPixel = cmdLine.GetFloat("mean_pixel");

        net = detectNet::Create(prototxt, modelName, meanPixel, class_labels, threshold, input,
                                out_blob ? NULL : out_cvg, out_blob ? out_blob : out_bbox, maxBatchSize);
    }
    else
    {
        // create detectNet from pretrained model
        // net = detectNet::Create(type, threshold, maxBatchSize);
        net = detectNet::Create(type, threshold, maxBatchSize, type_precision, type_device, allowGPUFallback_value);
    }

    if( !net )
        return NULL;

    // enable layer profiling if desired
    if( cmdLine.GetFlag("profile") )
        net->EnableLayerProfiler();

    // set overlay alpha value
    net->SetOverlayAlpha(cmdLine.GetFloat("alpha", DETECTNET_DEFAULT_ALPHA));

    return net;
}

Changes in detectNet.h (around line 193):

#if NV_TENSORRT_MAJOR > 4
    SSD_MOBILENET_V1,      /**< SSD Mobilenet-v1 UFF model, trained on MS-COCO */
    SSD_MOBILENET_V1_ONNX, /**< SSD Mobilenet-v1 ONNX model */
    SSD_MOBILENET_V2,      /**< SSD Mobilenet-v2 UFF model, trained on MS-COCO */
    SSD_INCEPTION_V2       /**< SSD Inception-v2 UFF model, trained on MS-COCO */

Changes in commandLine.h:

bool GetBool( const char* argName, bool defaultValue=false, bool allowOtherDelimiters=true ) const;

Changes in commandLine.cpp:

// GetBool
bool commandLine::GetBool( const char* string_ref, bool default_value, bool allowOtherDelimiters ) const
{
    if( argc < 1 )
        return default_value;

    bool bFound = false;
    bool value  = false;

    for( int i=ARGC_START; i < argc; i++ )
    {
        const int string_start = strFindDelimiter('-', argv[i]);

        if( string_start == 0 )
            continue;

        const char* string_argv = &argv[i][string_start];
        const int length = (int)strlen(string_ref);

        if( !strncasecmp(string_argv, string_ref, length) )
        {
            if( length + 1 <= (int)strlen(string_argv) )
            {
                int auto_inc = (string_argv[length] == '=') ? 1 : 0;
                value = atoi(&string_argv[length + auto_inc]) != 0;
            }
            else
            {
                value = false;
            }

            bFound = true;
            continue;
        }
    }

    if( bFound )
        return value;

    if( !allowOtherDelimiters )
        return default_value;

    // try looking for the argument with delimiters swapped
    char* swapped_ref = strSwapDelimiter(string_ref);

    if( !swapped_ref )
        return default_value;

    value = GetInt(swapped_ref, default_value, false) != 0;
    free(swapped_ref);
    return value;
}

The modified library is available here:
https://wetransfer.com/downloads/87b6e97a43a38d148baf14fec564b86620200722102800/c79f7246a5f2ce01087093db9611859f20200722102856/5ce4ba

To call the network:
net = jetson.inference.detectNet("ssd-mobilenet-v1-onnx", threshold=0.7, precision="FP16", device="GPU", allowGPUFallback=True)

@Pelepicier, I am unable to debug all the changes you made. I recommend going back to the original jetson-inference code and creating your model like this:

net = jetson.inference.detectNet(argv=['--model=my_model_path/ssd-mobilenet.onnx',
                                       '--labels=my_model_path/labels.txt',
                                       '--input-blob=input_0', '--output-cvg=scores', '--output-bbox=boxes',
                                       threshold=0.5)

This will use the argument parsing already built into detectNet and should work.

Hi @dusty_nv,

Thank you, this is exactly what I needed.
Now I can get 250 FPS with my custom retrained ONNX model with only a "Person" label (thanks to your scripts in jetson-inference). I use GPU + DLA_0 + DLA_1 with multiprocessing (in Python), and I only needed to change 8 lines in c/detectNet.cpp to make it work (to pass the device and precision in the args).
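Schematically, each device runs its own detectNet instance in its own process. A minimal, illustrative sketch of that kind of setup (not my exact script: the video sources are placeholders, and the precision/device keywords assume the modified bindings described above):

import multiprocessing as mp
import jetson.inference
import jetson.utils

def run_device(device, source_uri):
    # each process builds and owns its own TensorRT engine on one device
    net = jetson.inference.detectNet("ssd-mobilenet-v1-onnx", threshold=0.7,
                                     precision="FP16", device=device,
                                     allowGPUFallback=True)
    source = jetson.utils.videoSource(source_uri)   # placeholder input stream
    while True:
        img = source.Capture()
        detections = net.Detect(img)
        # ... hand the detections to the rest of the pipeline ...

if __name__ == "__main__":
    streams = {"GPU": "csi://0", "DLA_0": "file://video1.mp4", "DLA_1": "file://video2.mp4"}
    procs = [mp.Process(target=run_device, args=(dev, uri)) for dev, uri in streams.items()]
    for p in procs:
        p.start()
    for p in procs:
        p.join()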
I'm trying to get a bit more FPS, so I was wondering: what does the batch_size parameter change in the engine? It does not seem to have any effect on the inference time…

It only currently changes the max batch size that the TensorRT engine can support. It doesn’t actually do multi-image batching, as that would require additional pre/post-processing code and changes to the input streaming. I would recommend DeepStream for applications using multi-stream batching.

I'm missing a bracket when I use this. Are you missing this: ] ?

Yes, you need to add "]"!
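For completeness, the corrected call with the closing bracket, plus a minimal, illustrative use of the returned detections (the image path is a placeholder, and the exact jetson.utils helpers may differ depending on your version):

import jetson.inference
import jetson.utils

net = jetson.inference.detectNet(argv=['--model=my_model_path/ssd-mobilenet.onnx',
                                       '--labels=my_model_path/labels.txt',
                                       '--input-blob=input_0', '--output-cvg=scores', '--output-bbox=boxes'],
                                 threshold=0.5)

img = jetson.utils.loadImage('images/example_0.jpg')   # placeholder image path
detections = net.Detect(img)

for d in detections:
    print(net.GetClassDesc(d.ClassID), d.Confidence, d.Left, d.Top, d.Right, d.Bottom)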

Hi @dusty_nv, I'm back on this after a few months! Do you have any plans to implement multi-image batching as a module or anything like this? (Ideally, passing an array of images to detectnet would be perfect.)

I also succeeded in doing multi-image batching with TensorRT using another Python script (modified from object-detection-tensorrt-example/detect_objects_webcam.py at master · NVIDIA/object-detection-tensorrt-example · GitHub), but with poor performance (around 100 FPS using only the GPU, and changing the batch size from 1 to 8 increases the inference time linearly). How could I obtain the famous 850 FPS shown in the benchmarks?

I don’t currently have plans to implement batching in jetson-inference, as the primary use-case is for single-stream applications and demos/examples. I would recommend DeepStream or the TRT samples you found for batching.

You can run the benchmarks from this GitHub repo; on Xavier they use INT8, GPU + 2x DLA, and batching: https://github.com/NVIDIA-AI-IOT/jetson_benchmarks

Hi @dusty_nv, sorry for hijacking this thread, but my problem looks similar.

As already reported on "your" GitHub, I followed your tutorial and retrained SSD-Mobilenet-V1 on 9 classes of fruits. This was successful, so I now have an ssd-mobilenet.onnx file and a label file, and I wanted to integrate them into an existing infrastructure, which is more or less based on a mix of some DeepStream Python samples: https://github.com/NVIDIA-AI-IOT/deepstream_python_apps/tree/master/apps

There, an XXX_pgie_config.txt file is used to configure the engine. This works fine for models like resnet10.caffemodel and resnet34_peoplenet_pruned.etlt. I have been trying to figure out a similar configuration for the ssd-mobilenet.onnx, so far to no avail. I know my configuration must be incomplete at the moment, but I need a hint about what to add, remove, or change.

This is my current configuration, which is stitched together from various sources, e.g. Deep-Stream-ONNX/config_infer_custom_yolo.txt at master · thatbrguy/Deep-Stream-ONNX · GitHub

[property]
workspace-size=600
gpu-id=0
model-color-format=0
net-scale-factor=0.003921569790691137
onnx-file=/home/ubuntu/dragonfly-safety/jetson-inference/models/primary-detector-nano/ssd-mobilenet.onnx
labelfile-path=/home/ubuntu/dragonfly-safety/jetson-inference/models/primary-detector-nano/labels_onnx.txt
batch-size=3
network-mode=2
num-detected-classes=9
maintain-aspect-ratio=1
output-blob-names=grid
gie-unique-id=1
is-classifier=0

The quoted config also has a custom bbox parser. I didn't know what to do with that, so I have dropped it for now.

The whole thing starts out promising, but it soon fails with an error:

Using winsys: x11 
0:00:00.499388128 12676      0x7c98440 INFO                 nvinfer gstnvinfer.cpp:619:gst_nvinfer_logger:<primary-inference> NvDsInferContext[UID 1]: Info from NvDsInferContextImpl::buildModel() <nvdsinfer_context_impl.cpp:1716> [UID = 1]: Trying to create engine from model files
----------------------------------------------------------------
Input filename:   /home/ubuntu/dragonfly-safety/jetson-inference/models/primary-detector-nano/ssd-mobilenet.onnx
ONNX IR version:  0.0.6
Opset version:    9
Producer name:    pytorch
Producer version: 1.6
Domain:           
Model version:    0
Doc string:       
----------------------------------------------------------------
WARNING: [TRT]: onnx2trt_utils.cpp:220: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
WARNING: [TRT]: onnx2trt_utils.cpp:246: One or more weights outside the range of INT32 was clamped
WARNING: [TRT]: onnx2trt_utils.cpp:246: One or more weights outside the range of INT32 was clamped
WARNING: [TRT]: onnx2trt_utils.cpp:246: One or more weights outside the range of INT32 was clamped
INFO: [TRT]: Some tactics do not have sufficient workspace memory to run. Increasing workspace size may increase performance, please check verbose output.
ERROR: [TRT]: ../rtSafe/cuda/cutensorReformat.cpp (227) - Assertion Error in executeCutensor: 0 (validateInputsCutensor(src, dst))
ERROR: Build engine failed from config file
ERROR: failed to build trt engine.
0:01:59.530931154 12676      0x7c98440 ERROR                nvinfer gstnvinfer.cpp:613:gst_nvinfer_logger:<primary-inference> NvDsInferContext[UID 1]: Error in NvDsInferContextImpl::buildModel() <nvdsinfer_context_impl.cpp:1736> [UID = 1]: build engine file failed
0:01:59.534033449 12676      0x7c98440 ERROR                nvinfer gstnvinfer.cpp:613:gst_nvinfer_logger:<primary-inference> NvDsInferContext[UID 1]: Error in NvDsInferContextImpl::generateBackendContext() <nvdsinfer_context_impl.cpp:1822> [UID = 1]: build backend context failed
0:01:59.534300953 12676      0x7c98440 ERROR                nvinfer gstnvinfer.cpp:613:gst_nvinfer_logger:<primary-inference> NvDsInferContext[UID 1]: Error in NvDsInferContextImpl::initialize() <nvdsinfer_context_impl.cpp:1149> [UID = 1]: generate backend failed, check config file settings
0:01:59.545112136 12676      0x7c98440 WARN                 nvinfer gstnvinfer.cpp:812:gst_nvinfer_start:<primary-inference> error: Failed to create NvDsInferContext instance
0:01:59.545162033 12676      0x7c98440 WARN                 nvinfer gstnvinfer.cpp:812:gst_nvinfer_start:<primary-inference> error: Config file path: /tmp/tmpcm6h3n8z, NvDsInfer Error: NVDSINFER_CONFIG_FAILED
Error: gst-resource-error-quark: Failed to create NvDsInferContext instance (1): /dvs/git/dirty/git-master_linux/deepstream/sdk/src/gst-plugins/gst-nvinfer/gstnvinfer.cpp(812): gst_nvinfer_start (): /GstPipeline:pipeline0/GstNvInfer:primary-inference:
Config file path: /tmp/tmpcm6h3n8z, NvDsInfer Error: NVDSINFER_CONFIG_FAILED

Would you by chance have a pointer to what might be going wrong here?

OK, one step ahead. For some reason, batch-size=3 seems to be a problem. I'm using that value because I have 3 input cameras. To move forward I changed it to batch-size=1 and, indeed, the engine file was created:

ssd-mobilenet.onnx_b1_gpu0_fp16.engine

Not sure why it doesn't work with 3.

It crashed later, complaining that it was unable to parse bboxes, which is surely caused by the fact that I ignored the extra handler:

Using winsys: x11 
0:00:01.226569907 15513      0xd643640 INFO                 nvinfer gstnvinfer.cpp:619:gst_nvinfer_logger:<primary-inference> NvDsInferContext[UID 1]: Info from NvDsInferContextImpl::buildModel() <nvdsinfer_context_impl.cpp:1716> [UID = 1]: Trying to create engine from model files
----------------------------------------------------------------
Input filename:   /home/ubuntu/dragonfly-safety/jetson-inference/models/primary-detector-nano/ssd-mobilenet.onnx
ONNX IR version:  0.0.6
Opset version:    9
Producer name:    pytorch
Producer version: 1.6
Domain:           
Model version:    0
Doc string:       
----------------------------------------------------------------
WARNING: [TRT]: onnx2trt_utils.cpp:220: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
WARNING: [TRT]: onnx2trt_utils.cpp:246: One or more weights outside the range of INT32 was clamped
WARNING: [TRT]: onnx2trt_utils.cpp:246: One or more weights outside the range of INT32 was clamped
WARNING: [TRT]: onnx2trt_utils.cpp:246: One or more weights outside the range of INT32 was clamped
INFO: [TRT]: Some tactics do not have sufficient workspace memory to run. Increasing workspace size may increase performance, please check verbose output.
INFO: [TRT]: Detected 1 inputs and 4 output network tensors.
0:02:58.653771316 15513      0xd643640 INFO                 nvinfer gstnvinfer.cpp:619:gst_nvinfer_logger:<primary-inference> NvDsInferContext[UID 1]: Info from NvDsInferContextImpl::buildModel() <nvdsinfer_context_impl.cpp:1749> [UID = 1]: serialize cuda engine to file: /home/ubuntu/dragonfly-safety/jetson-inference/models/primary-detector-nano/ssd-mobilenet.onnx_b1_gpu0_fp16.engine successfully
INFO: [Implicit Engine Info]: layers num: 3
0   INPUT  kFLOAT input_0         3x300x300       
1   OUTPUT kFLOAT scores          3000x9          
2   OUTPUT kFLOAT boxes           3000x4          

ERROR: [TRT]: INVALID_ARGUMENT: Cannot find binding of given name: grid
0:02:58.674788473 15513      0xd643640 WARN                 nvinfer gstnvinfer.cpp:616:gst_nvinfer_logger:<primary-inference> NvDsInferContext[UID 1]: Warning from NvDsInferContextImpl::checkBackendParams() <nvdsinfer_context_impl.cpp:1670> [UID = 1]: Could not find output layer 'grid' in engine
0:02:58.772456952 15513      0xd643640 INFO                 nvinfer gstnvinfer_impl.cpp:313:notifyLoadModelStatus:<primary-inference> [UID 1]: Load new model:/tmp/tmpr93o712f sucessfully
0:02:59.703485356 15513      0xd1dcf20 ERROR                nvinfer gstnvinfer.cpp:613:gst_nvinfer_logger:<primary-inference> NvDsInferContext[UID 1]: Error in NvDsInferContextImpl::parseBoundingBox() <nvdsinfer_context_impl_output_parsing.cpp:59> [UID = 1]: Could not find output coverage layer for parsing objects
0:02:59.704273968 15513      0xd1dcf20 ERROR                nvinfer gstnvinfer.cpp:613:gst_nvinfer_logger:<primary-inference> NvDsInferContext[UID 1]: Error in NvDsInferContextImpl::fillDetectionOutput() <nvdsinfer_context_impl_output_parsing.cpp:733> [UID = 1]: Failed to parse bboxes
Segmentation fault (core dumped)

What am I supposed to do with these errors?

Could not find output layer 'grid' in engine
Could not find output coverage layer for parsing objects
Failed to parse bboxes

…and yet another step ahead:

I figured out that the parameter output-blob-names would probably have to be output-blob-names=boxes. Then only the "Failed to parse bboxes" problem remains, which probably really needs that custom parser…

EDIT: Maybe a better fit would be

output-blob-names=boxes;scores

EDIT 2: Not sure if I have now moved forward or back again. I found a bbox parser in /opt/nvidia/deepstream/deepstream-5.1/sources/objectDetector_SSD/nvdsinfer_custom_impl_ssd and patched it a bit (namely the CUDA version in the Makefile and the number of classes in nvdsparsebbox_ssd.cpp). I built it, copied the lib to my working directory, and adapted the config. But I'm still getting the "Failed to parse bboxes" crash.

My current config (the lib is loaded, I’m pretty sure):

[property]
workspace-size=800
gpu-id=0
model-color-format=0
net-scale-factor=0.003921569790691137
onnx-file=/home/ubuntu/dragonfly-safety/jetson-inference/models/primary-detector-nano/ssd-mobilenet.onnx
labelfile-path=/home/ubuntu/dragonfly-safety/jetson-inference/models/primary-detector-nano/labels_onnx.txt
model-engine-file=/home/ubuntu/dragonfly-safety/jetson-inference/models/primary-detector-nano/ssd-mobilenet.onnx_b1_gpu0_fp16.engine
batch-size=1
network-mode=2
num-detected-classes=10
maintain-aspect-ratio=1
gie-unique-id=1
is-classifier=0
output-blob-names=boxes;scores
parse-bbox-func-name=NvDsInferParseCustomSSD
custom-lib-path=/home/ubuntu/dragonfly-safety/jetson-inference/models/primary-detector-nano/libnvdsinfer_custom_impl_ssd.so

There is still a mile to go :)

EDIT 3: The error is, btw, "Could not find NMS layer buffer while parsing", and it comes from the lib… whatever that means.

EDIT 4: OK, now I'm stuck. Following this thread (Onnx model on deepstream5.0: Nvinfer error: Could not find NMS layer buffer while parsing - #4 by haneesh24199), I changed the two occurrences of NMS to boxes and scores, and now the whole thing crashes with a segfault without further notice… :(

BTW: In the meantime I also fixed the wrong num-detected-classes value, setting it to 9…

EDIT 5: Solution for EDIT 4: Onnx model on deepstream5.0: Nvinfer error: Could not find NMS layer buffer while parsing - #13 by foreverneilyoung

Hi @foreverneilyoung, it seems you have another thread going on the DeepStream forum about this, which is good because they know a lot more about DeepStream than I over there :)

You may have found this, but this is where the outputs of the ONNX-based SSD-Mobilenet are interpreted:

https://github.com/dusty-nv/jetson-inference/blob/2fb798e3e4895b51ce7315826297cf321f4bd577/c/detectNet.cpp#L818

  • the boxes layer is a buffer of float4's (left, top, right, bottom) with coordinates in [0,1], so they need to be multiplied by the image width/height

  • the scores layer is a buffer of floats, num_boxes * num_classes long. Each box has a confidence value for each class; the class with the maximum confidence is the class for that box. The confidences should be thresholded, because not every box is actually a detection.

It looks like your DeepStream custom bounding box parser would need to be modified to reflect the same parsing as above. Currently I think it is set up the 'TensorFlow way', where the score and bounding box coordinates are all output in the same layer. You can see in my code where I have that way implemented too, in order to run the TensorFlow UFF version of the models.
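As an illustration only (not the actual detectNet or DeepStream parser code), the post-processing amounts to something like the sketch below, where the array shapes follow the bindings printed above and the first class is assumed to be the BACKGROUND entry from the pytorch-ssd labels file:

import numpy as np

def parse_ssd_outputs(scores, boxes, img_width, img_height, threshold=0.5):
    # scores: (num_boxes, num_classes) floats
    # boxes:  (num_boxes, 4) normalized (left, top, right, bottom)
    detections = []
    for i in range(scores.shape[0]):
        class_id = int(np.argmax(scores[i]))          # class with the highest confidence for this box
        confidence = float(scores[i, class_id])
        if class_id == 0 or confidence < threshold:   # skip background and weak boxes
            continue
        left, top, right, bottom = boxes[i]
        detections.append((class_id, confidence,
                           left * img_width, top * img_height,
                           right * img_width, bottom * img_height))
    return detections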

Hi @foreverneilyoung, it seems you have another thread going on the DeepStream forum about this, which is good because they know a lot more about DeepStream than I over there :)

I completely agree. Today was a crash course…:)

You may have found this, but this is where the outputs of the ONNX-based SSD-Mobilenet are interpreted:

No, not yet. I was looking for that, but I was lost…

Cool. That should give me a new kick start. Thanks. I agree with your conclusion. I have checked a lot of the DeepStream samples; none of them is a perfect match.

It looks like your DeepStream custom bounding box parser would need to be modified to reflect the same parsing as above. Currently I think it is set up the 'TensorFlow way', where the score and bounding box coordinates are all output in the same layer. You can see in my code where I have that way implemented too, in order to run the TensorFlow UFF version of the models.

Maybe, at least it doesn’t fit.

Thanks for the pointer. Will work on this.

@dusty_nv Short additional question: what is mOutput in your code? I see a lot of references, but I can't currently see where it is assigned, or from what.

mOutput is an array of output tensors from TensorRT - it is defined in the tensorNet base class.

Is that identical to the buffer pointer provided by the bboxLayerIndex?

float *detectionOut = (float *) outputLayersInfo[bboxLayerIndex].buffer;

I wonder if it was because the ONNX was exported with the default of --batch-size=1
https://github.com/dusty-nv/pytorch-ssd/blob/e7b5af50a157c50d3bab8f55089ce57c2c812f37/onnx_export.py#L25

Try running onnx_export.py with --batch-size=3 instead.
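For example, something along these lines (the flags follow the linked onnx_export.py, and the model directory is a placeholder):

python3 onnx_export.py --batch-size=3 --model-dir=models/Person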