Description
I have been working with a GoogleNet model from torchvision and converted it to a TensorRT engine using GitHub - NVIDIA-AI-IOT/torch2trt: An easy to use PyTorch to TensorRT converter. I serialized the model to an engine file, which I then read into C++. Inference through the Python API (torch2trt and tensorrt) produces correct results.
From what I understand of inference, buffers are allocated according to the engine's bindings before execution, so there should normally be at least one input buffer and one output buffer. In my case, however, I only see an output buffer: engine->getNbBindings() returns 1, and the Python API reports the same value. My questions: how does TensorRT determine the number of bindings? Did I call torch2trt incorrectly, and if so, what is the correct call? Can I still run inference with only one binding in the C++ API? I have attached files below for better understanding.
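For context, a quick way I used to inspect what the engine reports is the snippet below. It is only a sketch against the TensorRT 5.x Python API; the engine path is from my directory layout, everything else is generic. TensorRT creates one binding per network input plus one per tensor marked as an output, so a GoogleNet classifier should normally report 2.

```python
# Sketch: enumerate the bindings of a serialized engine (TensorRT 5.x Python API).
# 'checkpoint/trt.engine' is the engine file from this post; adjust as needed.
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
with open('checkpoint/trt.engine', 'rb') as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

print('num_bindings:', engine.num_bindings)  # expected 2 for one input + one output
for i in range(engine.num_bindings):
    kind = 'input' if engine.binding_is_input(i) else 'output'
    print(i, engine.get_binding_name(i), kind, engine.get_binding_shape(i))
```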
Environment
TensorRT Version: 5.1.5
GPU Type: GeForce RTX 2080
Nvidia Driver Version: 440
CUDA Version: 10.0
CUDNN Version: 7.6.5
Operating System + Version: Ubuntu 18.04
Python Version (if applicable): Python 3.6.9
TensorFlow Version (if applicable):
PyTorch Version (if applicable):
Baremetal or Container (if container which image + tag): 1.4.0
Relevant Files
convert_to_trt.py - converts the checkpoint to a serialized engine file.
checkpoint/DistributedDataParallel_200_80.pt - Original PyTorch checkpoint
checkpoint/trt.engine - A pre-converted engine file (running convert_to_trt.py produces a fresh one)
include/torch2trt.h - header file
src/torch2trt.cpp - source file
Steps To Reproduce
- tar -xvf googlenet.tar.gz
- Run scripts/convert_to_trt.py (make sure the required libraries are installed)
- Create build directory
- cd build
- cmake .. && make -j8
- ./googlenet_node
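For reference, the shape of the conversion call in convert_to_trt.py follows the usual torch2trt pattern sketched below. This is illustrative, not the exact script contents: the checkpoint-loading details and input shape are assumptions. torch2trt derives the input binding from the example tensor passed in the list, so if the engine ends up with only one binding, this call is the first place to check.

```python
# Sketch of a typical torch2trt conversion (paths from this post, rest illustrative).
import torch
import torchvision
from torch2trt import torch2trt

# Load the trained GoogleNet checkpoint (state-dict loading assumed here).
model = torchvision.models.googlenet(pretrained=False).cuda().eval()
model.load_state_dict(torch.load('checkpoint/DistributedDataParallel_200_80.pt'))

# The example input is what registers the engine's input binding.
x = torch.ones((1, 3, 224, 224)).cuda()
model_trt = torch2trt(model, [x])

# Serialize the TensorRT engine so C++ can deserialize it later.
with open('checkpoint/trt.engine', 'wb') as f:
    f.write(model_trt.engine.serialize())
```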
Edit: I just realized I cannot upload a tar file. Hence, sharing a Google Drive link
https://drive.google.com/open?id=15VJCxSoa0VyQxNTUtos7i45_M9etWP_s
Directory Structure:
Parent Directory: GoogleNet
------ include
------ ------ torch2trt.h
------ build
------ src
------ ------ torch2trt.cpp
------ scripts
------ ------ convert_to_trt.py
------ checkpoint
------ ------ DistributedDataParallel_200_80.pt
------ ------ trt.engine