Hi all! I'm currently developing on the Jetson Nano and am looking for advice on improving inference performance. I'm using SageMaker as my training environment for SSD object detection with ResNet-50 as the base network, which exports .params and .json files for MXNet. I've built MXNet on the Nano using the autoinstaller from this forum, and I've been able to run inference on a USB webcam feed by more or less following this guide:
That said, inference is really slow: it takes around 4-5 seconds per frame at an input size of 512x512. I've already tried converting the model to another framework via ONNX and MMdnn, but my custom model has operators that neither format supports, so it looks like I'm stuck with MXNet. The MXNet website says that MXNet has TensorRT integration, but I can't find any good examples of it online. The one on the MXNet website is confusing at best and doesn't help with my particular use case.
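One way to see where the 4-5 seconds per frame actually go (capture, preprocessing, forward pass, display) is to time each stage on its own. Below is a minimal sketch under some assumptions: the helper name `time_stage` is hypothetical, and the stage being timed is wrapped in a zero-argument callable. Because MXNet queues operators asynchronously, the callable should end with a blocking call such as `.asnumpy()` on the output; otherwise the timer mostly measures how long it takes to enqueue the work. The first few runs are discarded as warm-up, since one-time setup costs can inflate them.

```python
import statistics
import time
from typing import Callable, Dict


def time_stage(fn: Callable[[], object], warmup: int = 3, iters: int = 10) -> Dict[str, float]:
    """Time a zero-argument callable, discarding the first `warmup` runs.

    Hypothetical helper: `fn` should block until its result is ready
    (e.g. end with .asnumpy() on the MXNet output), because MXNet
    returns before the queued computation has finished.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")
    # Warm-up runs absorb one-time costs such as memory allocation
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return {
        "mean_s": statistics.mean(samples),
        "median_s": statistics.median(samples),
        "min_s": min(samples),
        "max_s": max(samples),
    }


if __name__ == "__main__":
    # Stand-in workload. On the Nano, a forward-pass-only measurement could look like
    # (`model` and `batch` are hypothetical names for the loaded model and input Batch):
    #   lambda: (model.mod.forward(batch, is_train=False), model.mod.get_outputs()[0].asnumpy())
    stats = time_stage(lambda: sum(i * i for i in range(100_000)), warmup=1, iters=5)
    print({k: round(v * 1000, 2) for k, v in stats.items()}, "ms")
```

Timing the forward pass separately from `vid.read()` and `cv2.waitKey(...)` makes it possible to tell whether the model itself or the surrounding loop dominates the per-frame time.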
One thing that seems to be holding me back is that I can only run inference on the CPU. According to the MXNet website, to use the GPU all you have to do is change ctx=cpu() to ctx=gpu() and make sure your data is converted to float32 before passing it in (MXNet Python inference crash when copy from CPU to GPU · Issue #13332 · apache/incubator-mxnet · GitHub). However, when I do that, it still crashes my Jetson, apparently because it runs out of memory. Does this have anything to do with the custom build of MXNet for the Nano? If not, why would it do that?
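For reference on the float32 point: `cv2.VideoCapture` returns frames as `uint8` arrays in height × width × channel (HWC) layout with BGR channel order, while the network's `data` input is bound as (batch, channel, height, width). A minimal NumPy sketch of that conversion is below, assuming the model was trained on RGB input with raw 0-255 pixel values; the name `frame_to_batch` is hypothetical, and any mean/std normalization used during training is not shown. It performs the same steps as the `cvtColor`/`swapaxes` lines in `predict_from_cam`, with the float32 cast made explicit.

```python
import numpy as np


def frame_to_batch(frame_bgr: np.ndarray) -> np.ndarray:
    """Convert an OpenCV HWC uint8 BGR frame into a float32 NCHW batch of one.

    Hypothetical helper: assumes the network expects RGB input with raw
    0-255 values; add any mean/std normalization used during training.
    """
    if frame_bgr is None or frame_bgr.ndim != 3 or frame_bgr.shape[2] != 3:
        raise ValueError("expected an H x W x 3 image")
    rgb = frame_bgr[:, :, ::-1]          # BGR -> RGB (same effect as cv2.COLOR_BGR2RGB)
    chw = np.transpose(rgb, (2, 0, 1))   # HWC -> CHW, same as the two swapaxes calls
    batch = chw[np.newaxis, ...]         # add batch dimension -> (1, 3, H, W)
    # Cast to float32 and copy into contiguous memory before building the NDArray
    return np.ascontiguousarray(batch, dtype=np.float32)
```

The result could then be wrapped as `Batch([mx.nd.array(batch, ctx=mx.gpu())])` for a GPU-bound module, so the host-to-device copy happens once per frame on data that is already float32.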
Any suggestions are welcome and appreciated!
Here’s my code:
import mxnet as mx
import numpy as np
import cv2, os, urllib, argparse, time
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])
# Array of object labels for the custom network
object_categories = ['object 1', 'object 2']
# Load model
# Important: the checkpoint files must be named network-prefix-symbol.json and
# network-prefix-0000.params and must be located in the current directory
class ImagenetModel(object):
    def __init__(self, synset_path, network_prefix, params_url=None, symbol_url=None, synset_url=None,
                 context=mx.cpu(), label_names=['prob_label'], input_shapes=[('data', (1, 3, 512, 512))]):
        # Load the network parameters from default epoch 0
        sym, arg_params, aux_params = mx.model.load_checkpoint(network_prefix, 0)
        # Load the network into an MXNet module and bind the corresponding parameters
        # (bind with the same 512x512 shape the frames are resized to)
        self.mod = mx.mod.Module(symbol=sym, label_names=label_names, context=context)
        self.mod.bind(for_training=False, data_shapes=input_shapes)
        self.mod.set_params(arg_params, aux_params)
        self.camera = None

    def predict_from_cam(self, reshape=(512, 512), N=50):
        topN = []
        # `frame` is the global set by the capture code in __main__
        if frame is None:
            return topN
        # OpenCV captures in BGR order; convert to RGB
        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Resize image to fit network input
        img = cv2.resize(img, reshape)
        # HWC -> CHW, add a batch dimension, and cast to float32
        img = np.swapaxes(img, 0, 2)
        img = np.swapaxes(img, 1, 2)
        img = img[np.newaxis, :].astype(np.float32)
        # Run forward on the image, then fetch the first output
        self.mod.forward(Batch([mx.nd.array(img)]), is_train=False)
        prob = self.mod.get_outputs()[0].asnumpy()
        prob = np.squeeze(prob)
        global results
        results = [prob[i].tolist() for i in range(min(100, len(prob)))]
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="pull and load pre-trained resnet model to classify one image")
    parser.add_argument('--img', type=str, default='cam', help='input image for classification, if this is cam it captures from the webcam')
    parser.add_argument('--prefix', type=str, default='model_algo_1', help='the prefix of the pre-trained model')
    parser.add_argument('--label-name', type=str, default='softmax_label', help='the name of the last layer in the loaded network (usually softmax_label)')
    parser.add_argument('--synset', type=str, default='synset.txt', help='the path of the synset for the model')
    args = parser.parse_args()
    mod = ImagenetModel(args.synset, args.prefix, label_names=[args.label_name])
    print("predicting on " + args.img)
    if args.img == "cam":
        vid = cv2.VideoCapture(0)
        ret, frame = vid.read()
        cv2.imshow('frame', frame)
        # Wait up to 1000 ms for a 'q' keypress on the cv2 window
        # (this wait counts toward each frame's wall-clock time)
        if cv2.waitKey(1000) & 0xFF == ord('q'):