ResNet.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:kaggle-dsg-qualification 作者: Ignotus 项目源码 文件源码
def __init__(self, in_channels, out_channels, ksize=3, fiber_map='id',
             stride=1, pad=1, wscale=1, bias=0, nobias=False, use_cudnn=True, initialW=None, initial_bias=None):
    """Register the links of a residual block and its shortcut ("fiber map").

    Registers ``bn1`` (BatchNormalization over ``in_channels``), ``conv1``
    (``in_channels -> out_channels``, stride ``stride``), ``bn2``
    (BatchNormalization over ``out_channels``) and ``conv2``
    (``out_channels -> out_channels``, stride 1), plus a shortcut stored as
    ``self.fiber_map``. The forward pass is defined elsewhere (not visible
    here). NOTE(review): the BN-before-conv channel counts suggest a
    pre-activation ordering; confirm against ``__call__``.

    Args:
        in_channels: Channel count of the block input.
        out_channels: Channel count produced by both convolutions.
        ksize: Square kernel size of ``conv1``/``conv2``; must be odd.
        fiber_map: Shortcut type. ``'id'`` uses ``F.identity`` and requires
            ``in_channels == out_channels``. ``'linear'`` uses a 1x1
            convolution with the same ``stride`` as ``conv1``.
        stride: Stride of ``conv1`` and of the ``'linear'`` shortcut.
        pad: Padding of ``conv1``/``conv2``; must equal ``(ksize - 1) // 2``.
        wscale, bias, nobias, use_cudnn, initialW, initial_bias: Forwarded
            unchanged to every ``L.Convolution2D`` created here. NOTE(review):
            since the convolutions have different weight shapes, an
            array-valued ``initialW``/``initial_bias`` will presumably not fit
            all of them; an initializer object is the safe choice.

    Raises:
        ValueError: if ``ksize`` is even, ``pad != (ksize - 1) // 2``,
            ``fiber_map == 'id'`` while the channel counts differ, or
            ``fiber_map`` is neither ``'id'`` nor ``'linear'``.
    """
    # Explicit raises instead of asserts: asserts vanish under ``python -O``.
    if ksize % 2 != 1:
        raise ValueError('ksize must be odd, got {}'.format(ksize))
    if pad != (ksize - 1) // 2:
        raise ValueError(
            'pad must equal (ksize - 1) // 2 = {}, got {}'.format((ksize - 1) // 2, pad))

    # Options shared by every convolution in the block. They are passed by
    # keyword so they cannot be mis-bound positionally. Previously every
    # option except wscale was accepted but silently ignored.
    conv_kwargs = {
        'wscale': wscale,
        'bias': bias,
        'nobias': nobias,
        'use_cudnn': use_cudnn,
        'initialW': initialW,
        'initial_bias': initial_bias,
    }

    super(ResBlock, self).__init__(
        bn1=L.BatchNormalization(in_channels),
        conv1=L.Convolution2D(in_channels, out_channels, ksize, stride=stride, pad=pad, **conv_kwargs),
        bn2=L.BatchNormalization(out_channels),
        conv2=L.Convolution2D(out_channels, out_channels, ksize, stride=1, pad=pad, **conv_kwargs),
    )

    if fiber_map == 'id':
        if in_channels != out_channels:
            raise ValueError(
                "fiber_map='id' requires in_channels == out_channels, got {} and {}".format(
                    in_channels, out_channels))
        # NOTE(review): an identity shortcut only matches the residual branch
        # spatially when stride == 1; confirm callers never combine 'id' with
        # stride > 1.
        self.fiber_map = F.identity
    elif fiber_map == 'linear':
        # Use conv1's stride (was hard-coded to 2). A 1x1 conv with pad 0 and
        # stride s yields the same spatial size as conv1 (odd ksize, pad
        # (ksize - 1) // 2, stride s), so the shortcut always lines up with
        # the residual branch.
        self.add_link(
            'fiber_map',
            L.Convolution2D(in_channels, out_channels, 1, stride=stride, pad=0, **conv_kwargs))
    else:
        raise ValueError('Unimplemented fiber map {}'.format(fiber_map))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号