# Generate a model that upsamples the image.
import torch
import torch.nn as nn


class TestModel(nn.Module):
    def __init__(self, channels_deconv_in, channels_deconv_out, scale=255, stride=2):
        super().__init__()
        channels = 3
        # Per-channel normalization factor applied before and after the network.
        self.scale = torch.ones([1, channels, 1, 1], dtype=torch.float) * scale

        # Effectively a 1x1 convolution (only the center tap is nonzero) that copies
        # the 3 input channels into channels 0, -1, -2 of the deconvolution input;
        # all other channels stay zero.
        self.preconv = nn.Conv2d(channels, channels_deconv_in, kernel_size=3, padding=1)
        conv_weight = torch.zeros([channels_deconv_in, channels, 3, 3])
        for c in range(channels):
            conv_weight[-c, -c, 1, 1] = 1
        self.preconv.weight.data.copy_(conv_weight)
        self.preconv.bias.data.copy_(torch.zeros([channels_deconv_in], dtype=torch.float))

        # Transposed convolution that upsamples by `stride` via nearest-neighbor
        # replication: each input pixel is repeated over a stride x stride block.
        self.deconv = nn.ConvTranspose2d(channels_deconv_in, channels_deconv_out,
                                         kernel_size=stride, padding=0, stride=stride)
        deconv_weight = torch.zeros([channels_deconv_in, channels_deconv_out, stride, stride])
        for c in range(channels):
            deconv_weight[-c, -c, :, :] = 1
        self.deconv.weight.data.copy_(deconv_weight)
        self.deconv.bias.data.copy_(torch.zeros([channels_deconv_out], dtype=torch.float))

        # Box-filter convolution that smooths the upsampled result and maps the
        # active channels back to the 3 output channels. The -c indexing composes
        # with the layers above so the end-to-end channel mapping is the identity.
        self.conv = nn.Conv2d(channels_deconv_out, channels, kernel_size=stride + 1, padding=stride // 2)
        conv_weight = torch.zeros([channels, channels_deconv_out, stride + 1, stride + 1])
        for c in range(channels):
            conv_weight[-c, -c, :, :] = 1 / ((stride + 1) ** 2)
        self.conv.weight.data.copy_(conv_weight)
        self.conv.bias.data.copy_(torch.zeros([channels], dtype=torch.float))

    def forward(self, x):
        x = x / self.scale
        x = self.preconv(x)
        x = self.deconv(x)
        x = self.conv(x)
        x = x * self.scale
        return x


if __name__ == '__main__':
    # DLA has incorrect results when channels_deconv_out > 16.
    test_model = TestModel(channels_deconv_in=256, channels_deconv_out=16, stride=4)
    test_model = test_model.eval().cpu()
    print(test_model)
    # torch.autograd.Variable is deprecated; a plain tensor works for export.
    dummy_input = torch.randn(1, 3, 256 // 4, 512 // 4)
    with torch.no_grad():
        test_model(dummy_input)
    torch.onnx.export(test_model, dummy_input, "deconv_test.onnx", export_params=True)
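
    # Optional sanity check (a sketch, assuming onnxruntime and numpy are
    # installed): run the exported model through ONNX Runtime and compare it
    # against PyTorch on CPU. If these match, the DLA mismatch mentioned above
    # is isolated to the TensorRT/DLA side rather than the ONNX export.
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("deconv_test.onnx")
    input_name = sess.get_inputs()[0].name
    onnx_out = sess.run(None, {input_name: dummy_input.numpy()})[0]
    with torch.no_grad():
        torch_out = test_model(dummy_input).numpy()
    print("max abs diff vs ONNX Runtime:", np.abs(onnx_out - torch_out).max())

    # The exported file can then be fed to TensorRT on DLA to reproduce the
    # issue, e.g. (assumed invocation):
    #   trtexec --onnx=deconv_test.onnx --useDLACore=0 --fp16 --allowGPUFallback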