def init_weight(self, resnet101=None):
    """Copy pretrained ResNet-101 parameters into this model, in place.

    The stem (``conv1``/``bn1``) of ``resnet101`` is copied into
    ``self.res1``; every bottleneck unit of ``res2`` .. ``res5`` is copied
    into the same-named stage of ``self`` (units ``<stage>_a`` and
    ``<stage>_b1`` .. ``<stage>_b{n-1}``). For a convolution only ``W`` is
    copied; for a batch normalisation ``gamma``, ``beta``, ``avg_var`` and
    ``avg_mean`` are copied.

    Args:
        resnet101: Source network exposing ``conv1``, ``bn1`` and
            ``res2`` .. ``res5``, where each stage has an ``a`` unit and
            ``b1`` .. ``b{n-1}`` units. When ``None``,
            ``chainer.links.ResNet101Layers(pretrained_model='auto')`` is
            constructed and used as the source.

    Raises:
        AssertionError: If a destination link is the very same object as
            its source, or if a destination array and its source array
            differ in shape. All arrays of a link are checked before any
            of them is written, so the offending link is left untouched;
            links copied before the failure keep their new values.
    """
    if resnet101 is None:
        resnet101 = chainer.links.ResNet101Layers(pretrained_model='auto')

    # Bottleneck units per stage: one "a" unit followed by (n - 1) "b" units.
    n_layer_dict = {
        'res2': 3,
        'res3': 4,
        'res4': 23,
        'res5': 3
    }

    # The guards below raise AssertionError explicitly instead of using the
    # ``assert`` statement: the exception type seen by callers is unchanged,
    # but the checks are no longer stripped under ``python -O``, where a
    # broadcastable shape mismatch would otherwise be copied silently.
    def check_distinct(link, orig_link, where):
        if link is orig_link:
            raise AssertionError(
                '{}: destination and source are the same object'.format(
                    where))

    def check_shape(dst, src, where, what):
        if dst.shape != src.shape:
            raise AssertionError(
                '{}: {} shape mismatch (destination {}, source {})'.format(
                    where, what, dst.shape, src.shape))

    def copy_conv(conv, orig_conv, where):
        check_distinct(conv, orig_conv, where)
        check_shape(conv.W.array, orig_conv.W.array, where, 'W')
        conv.W.array[:] = orig_conv.W.array

    def copy_bn(bn, orig_bn, where):
        check_distinct(bn, orig_bn, where)
        pairs = (
            ('gamma', bn.gamma.array, orig_bn.gamma.array),
            ('beta', bn.beta.array, orig_bn.beta.array),
            ('avg_var', bn.avg_var, orig_bn.avg_var),
            ('avg_mean', bn.avg_mean, orig_bn.avg_mean),
        )
        # Validate every array first so a late mismatch cannot leave the
        # link half-copied.
        for what, dst, src in pairs:
            check_shape(dst, src, where, what)
        for _, dst, src in pairs:
            dst[:] = src

    def copy_bottleneck(bottle, orig_bottle, n_conv, where):
        # Links are named conv1/bn1 .. conv<n_conv>/bn<n_conv>.
        for i in range(1, n_conv + 1):
            conv_name = 'conv{}'.format(i)
            copy_conv(
                getattr(bottle, conv_name),
                getattr(orig_bottle, conv_name),
                '{}.{}'.format(where, conv_name))
            bn_name = 'bn{}'.format(i)
            copy_bn(
                getattr(bottle, bn_name),
                getattr(orig_bottle, bn_name),
                '{}.{}'.format(where, bn_name))

    def copy_block(block, orig_block, res_name):
        n_layer = n_layer_dict[res_name]
        # The "a" unit has four conv/bn pairs, the "b" units have three.
        unit_name = '{}_a'.format(res_name)
        copy_bottleneck(getattr(block, unit_name), orig_block.a, 4, unit_name)
        for i in range(1, n_layer):
            unit_name = '{0}_b{1}'.format(res_name, i)
            copy_bottleneck(
                getattr(block, unit_name),
                getattr(orig_block, 'b{}'.format(i)),
                3,
                unit_name)

    copy_conv(self.res1.conv1, resnet101.conv1, 'res1.conv1')
    copy_bn(self.res1.bn1, resnet101.bn1, 'res1.bn1')
    for res_name in ('res2', 'res3', 'res4', 'res5'):
        copy_block(
            getattr(self, res_name), getattr(resnet101, res_name), res_name)
评论列表
文章目录