def _construct_diff_model(self, base_model, keep_rgb=False):
    """Swap the first ``nn.Conv2d`` in ``self.base_model`` for a wider one.

    The replacement convolution takes ``3 * self.new_length`` input channels
    (preceded by the original input channels when ``keep_rgb`` is true). Every
    added channel is initialised with the original kernels averaged over
    their input-channel dimension; the bias, if present, is copied over.
    Output channels, kernel size, stride and padding are taken from the
    original layer.

    Args:
        base_model: Returned as-is. NOTE(review): the layers are looked up on
            ``self.base_model``, not on this argument, so the two are
            presumably expected to be the same object -- confirm with callers.
        keep_rgb: When true, the original kernels are kept and the averaged
            kernels are concatenated after them along the channel dimension.

    Returns:
        The ``base_model`` argument.

    Raises:
        ValueError: If ``self.base_model`` contains no ``nn.Conv2d``.
    """
    # ``Module.modules()`` yields the model itself followed by every nested
    # sub-module in depth-first order.
    modules = list(self.base_model.modules())
    # ``filter(...)[0]`` only works on Python 2 (``filter`` is lazy on
    # Python 3), so pick the first match with ``next`` instead.
    first_conv_idx = next(
        (idx for idx, module in enumerate(modules) if isinstance(module, nn.Conv2d)),
        None,
    )
    if first_conv_idx is None:
        raise ValueError("base model contains no nn.Conv2d layer to replace")
    conv_layer = modules[first_conv_idx]
    # Assumes the module visited immediately before the convolution is the
    # container that holds it as an attribute.
    container = modules[first_conv_idx - 1]

    # The first parameter is taken to be the kernel tensor; a second one,
    # when present, is the bias.
    params = [p.clone() for p in conv_layer.parameters()]
    kernels = params[0].data
    has_bias = len(params) == 2
    kernel_size = kernels.size()

    # Average the kernels over the input-channel axis and repeat that mean
    # for each of the 3 * new_length new input channels.
    diff_kernel_size = (kernel_size[0], 3 * self.new_length) + tuple(kernel_size[2:])
    new_kernels = kernels.mean(dim=1, keepdim=True).expand(diff_kernel_size).contiguous()
    if keep_rgb:
        # Keep the original kernels in front of the averaged ones.
        new_kernels = torch.cat((kernels, new_kernels), 1)

    # Take in_channels from the tensor actually built so the layer's
    # metadata always agrees with its weight.
    new_conv = nn.Conv2d(
        new_kernels.size(1),
        conv_layer.out_channels,
        conv_layer.kernel_size,
        conv_layer.stride,
        conv_layer.padding,
        bias=has_bias,
    )
    new_conv.weight.data = new_kernels
    if has_bias:
        new_conv.bias.data = params[1].data  # carry the bias over unchanged

    # The container's first state-dict key is "<layer name>.weight"; strip
    # the 7-character ".weight" suffix to recover the attribute name.
    layer_name = list(container.state_dict().keys())[0][:-7]

    # Replace the first convolution layer in place.
    setattr(container, layer_name, new_conv)
    return base_model
# (Web-page navigation residue, not part of the code: "Comment list", "Table of contents")