Description
Hello all.
I was exporting a model from PyTorch to TensorRT and ran into issues with the clamp operation. The model looks like the following:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class GaussianPolicyCNNActuatorsTrainv1(nn.Module):
    def __init__(self,
                 num_state_channels=5,
                 mask_valid_actuators=None,
                 hidden_dim=64,
                 num_layers=3,
                 activation="relu",
                 kernel_size=3
                 ):
        super(GaussianPolicyCNNActuatorsTrainv1, self).__init__()
        # Layers
        self.pi_list = nn.ModuleList()
        self.pi_list.append(
            torch.nn.Conv2d(num_state_channels, hidden_dim, kernel_size=kernel_size, stride=1, padding=1))
        for i in range(num_layers - 1):
            self.pi_list.append(nn.Conv2d(hidden_dim, hidden_dim, kernel_size=kernel_size, stride=1, padding=1))
        self.mean_out = torch.nn.Conv2d(hidden_dim, 1, kernel_size=kernel_size, stride=1, padding=1)
        self.log_std_out = torch.nn.Conv2d(hidden_dim, 1, kernel_size=kernel_size, stride=1, padding=1)
        # Activations
        if activation == "relu":
            self.activation = F.relu
        elif activation == "leaky_relu":
            self.activation = F.leaky_relu
        else:
            raise NotImplementedError
        # Mask valid actuators
        self.mask_valid_actuators = nn.Parameter(
            torch.tensor(mask_valid_actuators.reshape(-1, mask_valid_actuators.shape[0], mask_valid_actuators.shape[1]),
                         dtype=torch.float32), requires_grad=False)

    def forward(self, state):
        for i in range(len(self.pi_list)):
            state = self.activation(self.pi_list[i](state))
        mean, log_std = self.mean_out(state), self.log_std_out(state)
        log_std = torch.clamp(log_std, min=-20.0, max=2)
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.sample()
        y_t = torch.tanh(x_t)
        a = torch.mul(y_t, self.mask_valid_actuators)
        out = torch.cat([a, mean, std], dim=1)  # Concatenating along the channel dimension.
        return out

    def to(self, device):
        return super(GaussianPolicyCNNActuatorsTrainv1, self).to(device)
After exporting to ONNX and then to TensorRT, I could not reproduce the PyTorch results for an example input. While investigating, I tried a different policy, one that also returns the intermediate tensors. Interestingly, this one worked!
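For context, the export and comparison were done roughly like this (a minimal sketch; the mask, input shape, and file names are placeholders, not my exact script):

import numpy as np
import torch

# Hypothetical 28x28 actuator mask and dummy state, matching the shapes
# seen in the layer dumps further down.
mask = np.ones((28, 28), dtype=np.float32)
model = GaussianPolicyCNNActuatorsTrainv1(num_state_channels=5,
                                          mask_valid_actuators=mask).eval()
dummy_state = torch.randn(1, 5, 28, 28)

torch.onnx.export(model, dummy_state, "policy.onnx",
                  input_names=["state"], output_names=["output"],
                  opset_version=17)

# The ONNX file is then built into a TensorRT engine, for example with
#   trtexec --onnx=policy.onnx --saveEngine=policy.engine
# and the engine output for dummy_state is compared against model(dummy_state).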
class GaussianPolicyCNNActuatorsTrain(nn.Module):
    def __init__(self,
                 num_state_channels=5,
                 mask_valid_actuators=None,
                 hidden_dim=64,
                 num_layers=3,
                 activation="relu",
                 kernel_size=3
                 ):
        super(GaussianPolicyCNNActuatorsTrain, self).__init__()
        # Layers
        self.pi_list = nn.ModuleList()
        self.pi_list.append(
            torch.nn.Conv2d(num_state_channels, hidden_dim, kernel_size=kernel_size, stride=1, padding=1))
        for i in range(num_layers - 1):
            self.pi_list.append(nn.Conv2d(hidden_dim, hidden_dim, kernel_size=kernel_size, stride=1, padding=1))
        self.mean_out = torch.nn.Conv2d(hidden_dim, 1, kernel_size=kernel_size, stride=1, padding=1)
        self.log_std_out = torch.nn.Conv2d(hidden_dim, 1, kernel_size=kernel_size, stride=1, padding=1)
        # Activations
        if activation == "relu":
            self.activation = F.relu
        elif activation == "leaky_relu":
            self.activation = F.leaky_relu
        else:
            raise NotImplementedError
        # Mask valid actuators
        self.mask_valid_actuators = nn.Parameter(
            torch.tensor(mask_valid_actuators.reshape(-1, mask_valid_actuators.shape[0], mask_valid_actuators.shape[1]),
                         dtype=torch.float32), requires_grad=False)

    def forward(self, state):
        for i in range(len(self.pi_list)):
            state = self.activation(self.pi_list[i](state))
        mean = self.mean_out(state)
        log_std = self.log_std_out(state)
        log_std_clamped = torch.clamp(log_std, min=-20.0, max=2)
        std = log_std_clamped.exp()
        normal = Normal(mean, std)
        x_t = normal.sample()
        y_t = torch.tanh(x_t)
        a = torch.mul(y_t, self.mask_valid_actuators)
        out = torch.cat([a, mean, std, log_std, log_std_clamped], dim=1)  # Concatenating along the channel dimension.
        return out

    def to(self, device):
        return super(GaussianPolicyCNNActuatorsTrain, self).to(device)
After looking at the TensorRT engine layer information, I see the issue is in the clamp step. You can find the full logs attached below in layers_info_3channels.txt and layers_info_5channels.txt.
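For completeness, this kind of per-layer JSON can be obtained with TensorRT's engine inspector; a sketch along these lines (not necessarily how the attached files were generated, and it assumes the engine was built with detailed profiling verbosity so that layers are reported as dicts):

import json
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
with open("policy.engine", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())

inspector = engine.create_engine_inspector()
info = json.loads(inspector.get_engine_information(trt.LayerInformationFormat.JSON))
for layer in info["Layers"]:
    print(layer["Name"], layer["LayerType"])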
For the 3-channel model, TensorRT fuses the convolution and the clip into the following layer:
{'Name': '/log_std_out/Conv + /Clip',
'LayerType': 'CaskConvolution',
'Inputs': [{'Name': '/Relu_2_output_0',
'Location': 'Device',
'Dimensions': [1, 64, 28, 28],
'Format/Datatype': 'Channel major FP32 format where channel % 4 == 0'}],
'Outputs': [{'Name': '/Clip_output_0',
'Location': 'Device',
'Dimensions': [1, 1, 28, 28],
'Format/Datatype': 'Row major linear FP32'}],
'ParameterType': 'Convolution',
'Kernel': [3, 3],
'PaddingMode': 'kEXPLICIT_ROUND_DOWN',
'PrePadding': [1, 1],
'PostPadding': [1, 1],
'Stride': [1, 1],
'Dilation': [1, 1],
'OutMaps': 1,
'Groups': 1,
'Weights': {'Type': 'Float', 'Count': 576},
'Bias': {'Type': 'Float', 'Count': 1},
'HasBias': 1,
'HasReLU': 0,
'HasSparseWeights': 0,
'HasDynamicFilter': 0,
'HasDynamicBias': 0,
'HasResidual': 0,
'ConvXAsActInputIdx': -1,
'BiasAsActInputIdx': -1,
'ResAsActInputIdx': -1,
'Activation': 'CLIPPED_RELU',
'ParameterSubType': '[0.000000,2.000000]',
'TacticName': 'sm80_xmma_fprop_implicit_gemm_f32f32_tf32f32_f32_nhwckrsc_nchw_tilesize64x32x64_stage5_warpsize2x2x1_g1_tensor16x8x8_t1r3s3',
'TacticValue': '0xf5c5ffdea9383daa',
'StreamId': 1,
'Metadata': '[ONNX Layer: /log_std_out/Conv]\x1e[ONNX Layer: /Clip]'}
My impression is that it uses a clipped ReLU with the wrong bounds (0, 2) when I want to clip between -20 and 2.
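To spell out what that means numerically: a clipped ReLU with bounds (0, 2) only agrees with Clip(-20, 2) for non-negative inputs, so any negative log_std gets clipped to 0 (and std becomes 1 instead of exp(log_std)). A plain PyTorch illustration:

import torch

log_std = torch.tensor([-25.0, -3.0, 0.5, 5.0])
intended = torch.clamp(log_std, min=-20.0, max=2.0)  # what the model specifies
fused = torch.clamp(log_std, min=0.0, max=2.0)       # what CLIPPED_RELU [0, 2] computes
print(intended)  # approx. [-20., -3., 0.5, 2.]
print(fused)     # approx. [  0.,  0., 0.5, 2.]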
Instead, for the 5-channel model the clip appears as:
{'Name': 'PWN(/Clip)',
'LayerType': 'PointWiseV2',
'Inputs': [{'Name': '/mean_out/Conv || /log_std_out/Conv',
'Location': 'Device',
'Dimensions': [1, 1, 28, 28],
'Format/Datatype': 'Row major linear FP32'}],
'Outputs': [{'Name': 'output',
'Location': 'Device',
'Dimensions': [1, 1, 28, 28],
'Format/Datatype': 'Row major linear FP32'}],
'ParameterType': 'PointWise',
'ParameterSubType': 'PointWiseExpression',
'NbInputArgs': 1,
'InputArgs': ['arg0'],
'NbOutputVars': 1,
'OutputVars': ['var1'],
'NbParams': 0,
'Params': [],
'NbLiterals': 4,
'Literals': ['0.000000e+00f',
'1.000000e+00f',
'-2.000000e+01f',
'2.000000e+00f'],
'NbOperations': 2,
'Operations': ['auto const var0 = pwgen::iMin(literal3, arg0);',
'auto const var1 = pwgen::iMax(literal2, var0);'],
'TacticValue': '0x000000000000001c',
'StreamId': 0,
'Metadata': '[ONNX Layer: /Clip]'}
This one has the correct bounds of -20 and 2.
Note that I am sure the issue lies in log_std and the clamp, because the mean channel comes out exactly the same in PyTorch and TensorRT.
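The per-channel check I mean is roughly the following (using the 5-channel model; run_trt_engine is a hypothetical helper returning the engine output with the same (1, 5, H, W) layout, and model/dummy_state are as in the export sketch above):

import numpy as np

names = ["a", "mean", "std", "log_std", "log_std_clamped"]
torch_out = model(dummy_state).detach().numpy()   # (1, 5, H, W)
trt_out = run_trt_engine(dummy_state.numpy())     # hypothetical helper, same layout

for c, name in enumerate(names):
    max_diff = np.abs(torch_out[:, c] - trt_out[:, c]).max()
    print(f"{name}: max abs diff = {max_diff:.6f}")
# "a" differs by construction (random sampling); "mean" matches exactly,
# while the clamped log_std (and therefore std) diverges on the TensorRT side.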
Any idea? This seems like a bug.
layers_info_3channels.txt (14.3 KB)
layers_info_5channels.txt (17.5 KB)
Environment
TensorRT Version: 10.0.1
GPU Type: A100
Nvidia Driver Version: 550.54.15
CUDA Version: 12.4
CUDNN Version: 9.1.1
Operating System + Version: Rocky Linux 9.2 (Blue Onyx)
Python Version: 1.15.0
PyTorch Version: 2.2.2+cu121
ONNX Version: 1.15.0