Hi, I’m trying to use the CustomQKVToContextPluginDynamic plugin in my TensorRT engine, but it fails in some cases:
- For plugin_version=1, type_id=0, everything works fine.
- For plugin_version=1, type_id=1, trtexec raises:
  `Error[9]: [pluginV2Builder.cpp::reportPluginError::23] Error Code 9: Internal Error (/CustomQKVToContextPluginDynamic: could not find any supported formats consistent with input/output data types)`
- For plugin_version>1, nothing works.
Environment: NVIDIA Docker container 22.12
Relevant Files
Steps To Reproduce
I have the following code:
import torch
import torch.nn as nn
# The plugin's yaml file says that version 3 is not supported yet, so only
# plugin_version "1" is emitted by the symbolic below.
class CustomQKVToContextPluginDynamic(torch.autograd.Function):
    """Stand-in for TensorRT's ``CustomQKVToContextPluginDynamic`` plugin.

    In eager mode ``forward`` computes a reference (unmasked) multi-head
    scaled dot-product attention. During ``torch.onnx.export`` the
    ``symbolic`` method replaces the op with a single custom ONNX node named
    after the plugin.

    Input layout (as produced by the caller): ``(S, B, 3 * hidden_size, 1, 1)``.
    Along the third axis, each head's Q, K and V slices of length
    ``hidden_size // num_heads`` are stored next to each other, head after head.
    Output: ``(S, B, hidden_size, 1, 1)`` with heads concatenated.
    """

    @staticmethod
    def forward(ctx, input, hidden_size, num_heads, type_id=0):
        """Eager reference computation; ``type_id`` is only used on export.

        Returning a tensor of the plugin's output shape (rather than ``input``
        itself, whose third axis is ``3 * hidden_size``) matters for export.
        The tracer records the shape of whatever this returns.
        """
        seq_len, batch_size = input.shape[0], input.shape[1]
        head_size = hidden_size // num_heads
        # (S, B, N, 3, d): split the packed axis into heads, then Q/K/V.
        packed = input.reshape(seq_len, batch_size, num_heads, 3, head_size)
        # Each of q, k, v becomes (B, N, S, d).
        q, k, v = (t.permute(1, 2, 0, 3) for t in packed.unbind(dim=3))
        scores = torch.matmul(q, k.transpose(-1, -2)) * head_size ** -0.5
        context = torch.matmul(torch.softmax(scores, dim=-1), v)  # (B, N, S, d)
        # Back to (S, B, N * d, 1, 1) with heads concatenated on the hidden axis.
        return context.permute(2, 0, 1, 3).reshape(
            seq_len, batch_size, hidden_size, 1, 1
        )

    @staticmethod
    def symbolic(g, input, hidden_size, num_heads, type_id=0):
        """Emit the custom ONNX node consumed by TensorRT's ONNX parser.

        ``type_id`` is written to the plugin's ``type_id`` field. Per the
        TensorRT plugin README it presumably encodes the data type
        (0 = FP32, 1 = FP16); verify. The "could not find any supported
        formats" error for type_id=1 suggests the tensor feeding the plugin
        must then be FP16, e.g. by exporting with ``model.half()``.
        """
        return g.op(
            "CustomQKVToContextPluginDynamic",
            input,
            plugin_version_s="1",
            type_id_i=type_id,
            hidden_size_i=hidden_size,
            num_heads_i=num_heads,
            has_mask_i=0,  # no mask tensor is passed to the node
        )
class MyModule(nn.Module):
    """Self-attention layer whose core is the QKV-to-context plugin.

    Three independent linear layers produce Q, K and V. These are packed into
    a per-head interleaved ``(S, B, 3 * hidden_size, 1, 1)`` tensor and passed
    to ``CustomQKVToContextPluginDynamic``.

    Args:
        hidden_size: Model width. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads):
        # nn.Module.__init__ must run before any sub-module is assigned;
        # without it the nn.Linear assignments below raise AttributeError.
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.size_per_head = hidden_size // num_heads
        self.Wq = nn.Linear(self.hidden_size, self.hidden_size)
        self.Wk = nn.Linear(self.hidden_size, self.hidden_size)
        self.Wv = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, x):
        """Apply attention to ``x``.

        Args:
            x: Tensor of shape (seq_len, batch_size, hidden_size).

        Returns:
            Tensor of shape (seq_len, batch_size, hidden_size).
        """
        seq_len, batch_size = x.size(0), x.size(1)
        # (S, B, 3 * hidden): Q, K, V concatenated along the feature axis.
        qkv = torch.cat([self.Wq(x), self.Wk(x), self.Wv(x)], dim=2)
        # (S, B, 3, N, d) -> (S, B, N, 3, d) puts each head's Q, K and V slices
        # next to each other, then flatten to (S, B, 3 * hidden, 1, 1).
        qkv = qkv.view(seq_len, batch_size, 3, self.num_heads, self.size_per_head)
        qkv = (
            qkv.transpose(2, 3)
            .contiguous()
            .view(seq_len, batch_size, 3 * self.hidden_size, 1, 1)
        )
        context = CustomQKVToContextPluginDynamic.apply(
            qkv, self.hidden_size, self.num_heads
        )
        # Strip the two trailing singleton dims.
        return context.select(-1, 0).select(-1, 0)
from torch.onnx import OperatorExportTypes

# FP32 export on the GPU. To export an FP16 graph (presumably what the
# plugin's type_id=1 expects; verify), call .half() on both the model and the
# dummy input.
model = MyModule(768, 8).cuda()
# Named `dummy_input` rather than `input` to avoid shadowing the builtin.
dummy_input = torch.randn(512, 2, 768).cuda()  # (seq_len, batch_size, hidden_size)

# ONNX_FALLTHROUGH is presumably used so export does not reject the
# non-standard CustomQKVToContextPluginDynamic node; verify.
torch.onnx.export(
    model,
    (dummy_input,),
    "test.onnx",
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
    input_names=["input_0"],
    output_names=["output_0"],
)
This exports an ONNX file, which I then convert into a TensorRT engine with trtexec.