fcis_resnet101.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:chainer-fcis 作者: knorth55 项目源码 文件源码
def init_weight(self, resnet101=None):
    """Copy pretrained ResNet-101 backbone weights into ``self`` in place.

    The destination links ``self.res1`` ... ``self.res5`` are overwritten
    array-by-array from the corresponding links of ``resnet101``:

    * convolutions: only ``W.array`` is copied (biases, if any, are not);
    * batch normalization: ``gamma.array``, ``beta.array``, ``avg_var`` and
      ``avg_mean`` are copied.

    Destination stage blocks name their bottlenecks ``<res>_a`` and
    ``<res>_b1`` ... ``<res>_b<n-1>``. Source stage blocks name them ``a``
    and ``b1`` ... ``b<n-1>``. The stage depths are 3/4/23/3 for res2..res5.

    Args:
        resnet101: Source model providing ``conv1``, ``bn1`` and
            ``res2`` ... ``res5``. When ``None``,
            ``chainer.links.ResNet101Layers(pretrained_model='auto')`` is
            constructed and used.

    Raises:
        ValueError: If a destination link is the very same object as its
            source, or if any destination/source array shapes differ. All
            shapes of a conv/bn link are checked before that link is written,
            so a failing link is left unmodified. Links processed earlier
            keep their copied values.
    """
    if resnet101 is None:
        resnet101 = chainer.links.ResNet101Layers(pretrained_model='auto')

    # Number of bottlenecks per stage in ResNet-101.
    n_layer_dict = {
        'res2': 3,
        'res3': 4,
        'res4': 23,
        'res5': 3,
    }

    # Explicit checks instead of ``assert``: asserts are removed under
    # ``python -O``. ``[:] =`` would then silently broadcast compatible but
    # mismatched shapes instead of failing.
    def check_distinct(link, orig_link):
        if link is orig_link:
            raise ValueError(
                'destination and source link are the same object: {!r}'
                .format(link))

    def check_shape(name, dst, src):
        if dst.shape != src.shape:
            raise ValueError(
                'shape mismatch for {}: destination {} vs source {}'
                .format(name, dst.shape, src.shape))

    def copy_conv(conv, orig_conv):
        check_distinct(conv, orig_conv)
        check_shape('W', conv.W.array, orig_conv.W.array)
        conv.W.array[:] = orig_conv.W.array

    def copy_bn(bn, orig_bn):
        check_distinct(bn, orig_bn)
        pairs = (
            ('gamma', bn.gamma.array, orig_bn.gamma.array),
            ('beta', bn.beta.array, orig_bn.beta.array),
            ('avg_var', bn.avg_var, orig_bn.avg_var),
            ('avg_mean', bn.avg_mean, orig_bn.avg_mean),
        )
        # Validate every array first so a mismatch leaves ``bn`` untouched.
        for name, dst, src in pairs:
            check_shape(name, dst, src)
        for _, dst, src in pairs:
            dst[:] = src

    def copy_bottleneck(bottle, orig_bottle, n_conv):
        # Bottleneck sublinks are named conv1/bn1 ... conv<n>/bn<n>.
        for i in range(1, n_conv + 1):
            conv_name = 'conv{}'.format(i)
            copy_conv(getattr(bottle, conv_name),
                      getattr(orig_bottle, conv_name))
            bn_name = 'bn{}'.format(i)
            copy_bn(getattr(bottle, bn_name), getattr(orig_bottle, bn_name))

    def copy_block(block, orig_block, res_name):
        n_layer = n_layer_dict[res_name]
        # The first bottleneck of a stage has a fourth conv/bn pair
        # (presumably the projection shortcut); the rest have three.
        copy_bottleneck(getattr(block, '{}_a'.format(res_name)),
                        orig_block.a, 4)
        for i in range(1, n_layer):
            bottle = getattr(block, '{0}_b{1}'.format(res_name, i))
            orig_bottle = getattr(orig_block, 'b{}'.format(i))
            copy_bottleneck(bottle, orig_bottle, 3)

    copy_conv(self.res1.conv1, resnet101.conv1)
    copy_bn(self.res1.bn1, resnet101.bn1)
    for res_name in ('res2', 'res3', 'res4', 'res5'):
        copy_block(getattr(self, res_name), getattr(resnet101, res_name),
                   res_name)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号