Description
When I deploy CodeFormer with TensorRT, I get the error “could not find plugin: ScatterElements version: 1”. Because of environment constraints, I can’t upgrade TensorRT to 8.2+.
Then I tried to convert ScatterElements to ScatterND, but got another error: Invalid Node: concat_1022
I traced the problem to this code:
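min_tmp = torch.zeros(int(indices.shape[0]) * int(indices.shape[1]) * self.codebook_size).to(indices)  # flat [N * codebook_size] buffer; must start at zero for a one-hot result
indices = indices.flatten()  # flatten() is not in-place, so the result has to be reassigned
tmp = torch.add(self.tmp, indices)  # self.tmp is assumed to hold per-row offsets (row * codebook_size)
min_tmp[tmp] = 1  # set one entry per row
min_tmp = min_tmp.view(256, self.codebook_size)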
The original Python code:
indices = indices.view(-1, 1) # [batch * 256, 1]
min_encodings = torch.zeros(int(indices.shape[0]), self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1)
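One possible way around both errors is to avoid scatter and indexed assignment entirely: comparing the indices against `torch.arange(codebook_size)` with broadcasting yields the same one-hot matrix as `scatter_(1, indices, 1)`. This is a minimal sketch, assuming `indices` holds integer code ids in `[0, codebook_size)`; `one_hot_no_scatter` is a hypothetical helper name. When traced by `torch.onnx.export` the comparison should appear as Equal and Cast nodes rather than ScatterElements or ScatterND, but whether the TensorRT 8.0.1 parser accepts the resulting graph is an assumption that still needs to be checked.

```python
import torch


def one_hot_no_scatter(indices: torch.Tensor, codebook_size: int) -> torch.Tensor:
    """One-hot encode integer `indices` into an [N, codebook_size] tensor
    without scatter_ or indexed assignment (hypothetical replacement helper)."""
    indices = indices.view(-1, 1)  # [N, 1]
    # Code ids 0..codebook_size-1 as a row vector: [1, codebook_size].
    codes = torch.arange(codebook_size, device=indices.device).view(1, -1)
    # Broadcasting compares every index with every code id -> [N, codebook_size] bool mask,
    # then cast to the indices' dtype, like `.to(indices)` in the original code.
    return (indices == codes).to(indices.dtype)
```

In the original code this would replace the `torch.zeros(...)` and `scatter_` pair, for example `min_encodings = one_hot_no_scatter(indices, self.codebook_size)`. The intermediate boolean mask has the same shape as the zeros tensor it replaces, so memory use should be comparable.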
Thanks for your help.
Environment
TensorRT Version: 8.0.1.6
Relevant Files
The ONNX file: CodeFormer.onnx (Google Drive)