from copy import deepcopy

import torch.nn as nn
from inplace_abn import InPlaceABN, InPlaceABNSync
from pytorch_quantization import nn as quant_nn


class build_co_train_model_depthnet(nn.Module):
    def __init__(self, module1, module2, fine_tune, gamma=0.5, progressive_IKD_factor=False):
        super(build_co_train_model_depthnet, self).__init__()
        # dictionary of MbVNet skip connections
        self.gamma = gamma
        self.progressive_IKD_factor = progressive_IKD_factor
        self.names = []
        self.num_convs = 0
        self.fine_tune = fine_tune

        # Register the frozen teacher modules (module1) next to fresh copies of the
        # student modules (module2) for every conv/ABN layer of the MbVNet backbone.
        for i, ((n1, m1), (n2, m2)) in enumerate(zip(module1.named_modules(), module2.named_modules())):
            if isinstance(m1, (nn.Conv2d, InPlaceABN, InPlaceABNSync, nn.ConvTranspose2d)) \
                    and n1.split('_')[0] == 'features':
                self.names.append(n1[:9] + n1[9:].replace('_', '.'))
                m_new = deepcopy(m2)
                for param in m1.parameters():
                    param.requires_grad = False
                self.add_module('_'.join(n1.split('.')) + '_teacher', m1)
                self.add_module('_'.join(n2.split('.')), m_new)
                self.insert_linear_layer(n1, m1, n2, m2)

        # Insert the lat/up/top FPN layers in the appropriate order
        n1 = 'latlayer0_teacher'
        n2 = 'latlayer0'
        m1 = module1._modules[n2]
        m2 = module2._modules[n2]
        self.add_module(n1, m1)
        self.add_module(n2, m2)
        self.insert_linear_layer(n1, m1, n2, m2)
        self.names.append(n2)

        triplet = ['latlayer', 'uplayer', 'toplayer']
        for i in range(1, 5):
            for layer in triplet:
                # In the drop-layer DepthNet there are no top/lat 3 and 4 layers
                if (i == 3 or i == 4) and (layer == 'latlayer' or layer == 'toplayer'):
                    continue
                n1 = layer + str(i) + '_teacher'
                n2 = layer + str(i)
                m1 = module1._modules[n2]
                m2 = module2._modules[n2]
                self.add_module(n1, m1)
                self.add_module(n2, m2)
                self.insert_linear_layer(n1, m1, n2, m2)
                self.names.append(n2)

        self.add_module('conv5_teacher', module1._modules['conv5'])
        self.add_module('conv5', module2._modules['conv5'])
        self.names.append('conv5')
        return

    def insert_linear_layer(self, n1, m1, n2, m2):
        if isinstance(m1, nn.Conv2d):
            self.num_convs += 1
            d1 = m1.out_channels
            d2 = m2.out_channels
            if d1 != d2:
                print(f'Dimension mismatch for layer {n1}: {d1} and {d2}. Using IKD with linear projection')
                M = nn.Linear(d2, d1, bias=False)
                for param in M.parameters():
                    param.requires_grad = False
                self.add_module('_'.join(n2.split('.')) + '_M', M)
    # DepthNet has two sorts of skip connections: inside the residual blocks and
    # between the MbVNet backbone and the FPN. Save the tensors along the forward path.
    def forward(self, x):
        residual_skips = {#'features_2.0.conv.2.1': [],
                          'features_3.0.conv.2.1': [],
                          'features_3.1.conv.2.1': [],
                          'features_4.0.conv.2.1': [],
                          'features_4.1.conv.2.1': [],
                          'features_4.2.conv.2.1': [],
                          'features_5.0.conv.2.1': [],
                          'features_5.1.conv.2.1': [],
                          'features_6.0.conv.2.1': [],
                          'features_6.1.conv.2.1': []}
        # MbVNet layer names at which to store tensors for the FPN lateral connections
        feats_names = [#'features_1.0.conv.1.1',
                       'features_2.1.conv.2.1',
                       'features_3.2.conv.2.1',
                       'features_5.2.conv.2.1']
        # FPN layers that are calculated separately, no need to forward them again
        layers_to_skip = ['uplayer1', 'uplayer2']
        mod_feats = []
        # extract_layers is assumed to be a project-local helper that returns the
        # ordered list of submodule names registered in __init__.
        list_layers = extract_layers(self)
        layer_idx = 0
        names_idx = 0
        x = x.permute(0, 3, 1, 2).contiguous()  # NHWC -> NCHW
        while layer_idx < len(list_layers):
            if list_layers[layer_idx] == '':
                layer_idx += 1
                continue
            else:
                #layer = extract_layer(self, list_layers[layer_idx])
                layer = getattr(self, list_layers[layer_idx])
                layer_idx += 1
                if self.names[names_idx] == 'latlayer0':
                    x = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)(layer(x))
                    #x = layer(x)
                elif self.names[names_idx] == 'latlayer1':
                    x = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)(layer(mod_feats[1])) + \
                        quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)(self.uplayer1(x))
                    #x = layer(mod_feats[1]) + self.uplayer1(x)
                elif self.names[names_idx] == 'toplayer1':
                    x = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)(self.toplayer1(x))
                    #x = self.toplayer1(x)
                elif self.names[names_idx] == 'latlayer2':
                    x = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)(layer(mod_feats[0])) + \
                        quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)(self.uplayer2(x))
                    #x = layer(mod_feats[0]) + self.uplayer2(x)
                elif self.names[names_idx] == 'toplayer2':
                    x = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)(self.toplayer2(x))
                    #x = self.toplayer2(x)
                elif self.names[names_idx] == 'uplayer3':
                    x = self.uplayer3(x)
                elif self.names[names_idx] == 'uplayer4':
                    x = self.uplayer4(x)
                elif self.names[names_idx] not in layers_to_skip:
                    x = layer(x)

                # Search for a tensor saved earlier whose target is the current layer
                for key in residual_skips:
                    if len(residual_skips[key]) != 0:
                        if residual_skips[key][0][1] == self.names[names_idx]:
                            x = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)(x) + \
                                quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)(residual_skips[key][0][0])
                            #x = x + residual_skips[key][0][0]

                # deal with the residual skip connections - save the tensor
                if self.names[names_idx] in residual_skips.keys():
                    residual_skips[self.names[names_idx]].append((x, self.names[names_idx + 6]))  # 6 for ABN

                # deal with the MbVNet-to-FPN skip connections
                if self.names[names_idx] in feats_names:
                    mod_feats.append(x)

                names_idx += 1

        x = x.permute(0, 2, 3, 1).contiguous()  # NCHW -> NHWC
        depth = x[:, :, :, 0]
        error = x[:, :, :, 1]
        rgb_feats = x[:, :, :, 2:].contiguous()
        return depth, error, rgb_feats
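

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). `build_depthnet` is a hypothetical
# stand-in for whatever constructor the project actually uses to build the
# full-precision teacher and the drop-layer student; it is not defined in this
# file, and the input shape below is illustrative. The wrapper takes an NHWC
# tensor (forward() permutes to NCHW internally) and returns the depth map,
# the error map, and the remaining RGB feature channels.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    teacher = build_depthnet(pretrained=True)   # hypothetical teacher constructor
    student = build_depthnet(drop_layers=True)  # hypothetical drop-layer student
    model = build_co_train_model_depthnet(teacher, student, fine_tune=False)
    model.eval()

    dummy = torch.randn(1, 256, 512, 4)  # NHWC; channel count depends on the DepthNet variant
    with torch.no_grad():
        depth, error, rgb_feats = model(dummy)
    print(depth.shape, error.shape, rgb_feats.shape)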