Inference with TensorRT model -- PyTorch BERT model > ONNX > TensorRT > Inference?

I have converted my trained PyTorch BERT model to ONNX, and from ONNX to TensorRT. The challenge I am having now is getting the TensorRT engine to return probability values for an input text. The PyTorch and ONNX versions of the model take the input_ids and attention mask as inputs and yield the predictions (input_text_prediction -- see below). Given that the TensorRT engine is the final conversion of the original PyTorch model, my intuition tells me it needs to take the same inputs. I have been reading/following this guideline, but their approach seems overly complex -- or am I oversimplifying the inference process with TensorRT?

Yes, I have done my Google/Stack Overflow research and have not found an answer or guideline besides the link above. Any help, guidance, or resources would be greatly appreciated.

if ONNX:
    ort_inputs = {
        'input_ids': encoding["input_ids"].cpu().reshape(1, 512).numpy(),
        'input_mask': encoding["attention_mask"].cpu().reshape(1, 512).numpy(),
    }
    ort_outputs = session_name.run(None, ort_inputs)  # session_name --> defined elsewhere
    input_text_prediction = list(ort_outputs[0][0])

if pytorch_model:
    input_text_prediction = model_name(encoding["input_ids"], encoding["attention_mask"])
    input_text_prediction = input_text_prediction.detach().numpy()[0]
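
For concreteness, a minimal sketch of what the TensorRT inference step might look like -- assuming the TensorRT 8 Python API with pycuda, an engine serialized to bert.engine, int32 inputs, a binding order of (input_ids, input_mask, output), and a num_labels-sized classifier output; these names and details are placeholders, not verified:

import numpy as np
import tensorrt as trt
import pycuda.autoinit  # initializes a CUDA context
import pycuda.driver as cuda

# deserialize the engine built from the ONNX model
logger = trt.Logger(trt.Logger.WARNING)
with open("bert.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

# same inputs as the PyTorch/ONNX versions, as contiguous int32 arrays
input_ids = encoding["input_ids"].cpu().reshape(1, 512).numpy().astype(np.int32)
input_mask = encoding["attention_mask"].cpu().reshape(1, 512).numpy().astype(np.int32)
output = np.empty((1, num_labels), dtype=np.float32)  # num_labels: size of the classifier head

# allocate device buffers and copy the inputs to the GPU
d_input_ids = cuda.mem_alloc(input_ids.nbytes)
d_input_mask = cuda.mem_alloc(input_mask.nbytes)
d_output = cuda.mem_alloc(output.nbytes)
cuda.memcpy_htod(d_input_ids, input_ids)
cuda.memcpy_htod(d_input_mask, input_mask)

# binding order must match the input/output order of the ONNX graph
context.execute_v2([int(d_input_ids), int(d_input_mask), int(d_output)])
cuda.memcpy_dtoh(output, d_output)
input_text_prediction = list(output[0])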

BERT works with token IDs only, so I think you should run your text through a tokenizer (e.g. Hugging Face's AutoTokenizer) before feeding the data to the model's inputs.

Here is my example with onnxruntime:

from transformers import AutoTokenizer
import numpy as np
import onnx
import torch
import onnxruntime
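
# NOTE (assumption, not the author's code): the example below calls
# normalize_logits() and indexes a `punctuation` table without defining them.
# These are plausible stand-ins: normalize_logits as a softmax over the class
# axis, and `punctuation` mapping each punctuation-class index to the string
# appended after a token.
def normalize_logits(logits):
    e = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

punctuation = [' ', ', ', '. ', '? ']  # assumed class order: none, comma, period, question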

class OnnxRuntimeModelPunct:
    def __init__(self):
        model_path = 'data/networks/asr/conformer/Punctuation_and_Capitalization.onnx'
        self.onnx_model = onnx.load(model_path)
        onnx.checker.check_model(self.onnx_model)  # sanity-check the graph before creating a session
        self.ort_session = onnxruntime.InferenceSession(model_path, providers=['CUDAExecutionProvider'])
        self.inputs = self.ort_session.get_inputs()
        self.outputs = self.ort_session.get_outputs()
        
        # log the session's expected input bindings
        for idx, binding in enumerate(self.inputs):
            print('')
            print(f"input {idx} - {binding.name}")
            print(f"   shape: {binding.shape}")
            print(f"   type:  {binding.type}")
            print('')

        self.tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/distilrubert-tiny-cased-conversational-v1")
        
        
    def execute(self, inputs, return_dict=False, **kwargs):
        # accept a single array, a positional list/tuple, or a dict keyed by
        # the session's input names
        if isinstance(inputs, np.ndarray):
            inputs = [inputs]

        assert len(inputs) == len(self.inputs)

        if isinstance(inputs, (list, tuple)):
            inputs = {self.inputs[i].name: array for i, array in enumerate(inputs)}
        elif not isinstance(inputs, dict):
            raise ValueError(f"inputs must be a list, tuple, or dict (instead got type '{type(inputs).__name__}')")
            
        outputs = self.ort_session.run(None, inputs)
        
        if return_dict:
            return {self.outputs[i].name : output for i, output in enumerate(outputs)}
            
        if len(outputs) == 1:
            return outputs[0]
        
        return outputs

    def call(self, query):

        encodings = self.tokenizer(
            text=query,
            padding='longest',
            truncation=True,
            max_length=1024,
            return_tensors='np',
            return_token_type_ids=True,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            return_special_tokens_mask=True,
        )

        inputs = {}
        
        for binding in self.inputs:
            if binding.name not in encodings:
                raise ValueError(f"the encoded inputs from the tokenizer don't contain '{binding.name}'")

            inputs[binding.name] = encodings[binding.name]
            #print("{}: {}".format(binding.name, encodings[binding.name]))
                    
        # run the model and turn the logits into per-token class predictions
        punct_logits, capit_logits = self.execute(inputs)
        punct_logits = normalize_logits(punct_logits)
        capit_logits = normalize_logits(capit_logits)

        punct_preds = np.argmax(punct_logits, axis=-1)
        capit_preds = np.argmax(capit_logits, axis=-1)

        # flatten the batch-of-one predictions into flat per-token lists
        punctpreds = []
        capitpreds = []
        for n in range(punct_preds.size):
            punctpreds.append(punct_preds[0][n])
            capitpreds.append(capit_preds[0][n])

        # rebuild the text: skip special tokens ('[CLS]', '[SEP]', ...), merge
        # '##' word pieces back onto the previous token, capitalize tokens the
        # capitalization head flags, and append the predicted punctuation
        out = []
        for index, i in enumerate(self.tokenizer.encode(query)):
            token = self.tokenizer.decode([i])
            if '[' not in token:
                if '##' not in token:
                    if capitpreds[index] == 1:
                        out.append(token.title())
                    else:
                        out.append(token)
                else:
                    out.pop()
                    out.append(token.replace("##", ""))
                out.append(punctuation[punctpreds[index]])
        output = "".join(out)

        # punct_position: the word count when no '. ' boundary was found, 0 otherwise
        sentence = output.split(". ", 1)
        if len(sentence) == 1:
            sentence_piece = sentence[0].split(" ")
            punct_position = len(sentence_piece)
        else:
            punct_position = 0
        return output, punct_position


punct = OnnxRuntimeModelPunct()
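
Usage then looks something like this (the query string is just a placeholder):

output, punct_position = punct.call("do you know what time it is")
print(output, punct_position)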