residual_block_2d.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:gconv_experiments 作者: tscohen 项目源码 文件源码
def __init__(self, in_channels, out_channels, ksize=3, fiber_map='id', conv_link=L.Convolution2D,
                 stride=1, pad=1, wscale=1):
        """Build the layers of a 2D residual block.

        Registers these links:

        - Two BatchNormalization links: ``bn1`` over ``in_channels`` and
          ``bn2`` over ``out_channels``.
        - Two convolution links built with ``conv_link``:

          - ``conv1``: ``in_channels -> out_channels``, using ``stride``.
          - ``conv2``: ``out_channels -> out_channels``, stride 1.

        It also sets a ``fiber_map`` attribute that maps the input to
        ``out_channels``. Presumably this is the shortcut path of the
        residual connection; the forward pass is not visible here.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            ksize: Kernel size used by both convolutions; must be odd.
            fiber_map: How the block input is mapped to ``out_channels``:

                - ``'id'`` uses ``F.identity`` and requires
                  ``in_channels == out_channels``.
                - ``'linear'`` registers a 1x1 ``conv_link`` (pad 0, same
                  ``stride``) as the ``fiber_map`` link.
                - ``'zero_pad'`` is not implemented.

            conv_link: Callable that constructs a convolution link. It is
                called with keyword arguments ``in_channels``,
                ``out_channels``, ``ksize``, ``stride``, ``pad`` and
                ``wscale``.
            stride: Stride of ``conv1`` and of the ``'linear'`` fiber map.
            pad: Padding of both convolutions; must equal
                ``(ksize - 1) // 2``.
            wscale: Forwarded unchanged to every ``conv_link`` call.

        Raises:
            AssertionError: If ``ksize`` is even.
            NotImplementedError: If ``pad != (ksize - 1) // 2``, or if
                ``fiber_map == 'zero_pad'``.
            ValueError: If ``fiber_map == 'id'`` while the channel counts
                differ, or if ``fiber_map`` is not a recognised option.
        """
        # NOTE(review): assert is stripped under ``python -O``; kept (rather
        # than raising ValueError) so the exception type callers see is unchanged.
        assert ksize % 2 == 1

        # Only symmetric "same"-style padding for an odd kernel is supported.
        if pad != (ksize - 1) // 2:
            raise NotImplementedError(
                'Only pad == (ksize - 1) // 2 is supported, got pad=%r for ksize=%r.' % (pad, ksize))

        super(ResBlock2D, self).__init__(
            bn1=L.BatchNormalization(in_channels),
            conv1=conv_link(
                in_channels=in_channels, out_channels=out_channels, ksize=ksize, stride=stride, pad=pad, wscale=wscale),
            bn2=L.BatchNormalization(out_channels),
            # conv2 always uses stride 1; any downsampling happens in conv1.
            conv2=conv_link(
                in_channels=out_channels, out_channels=out_channels, ksize=ksize, stride=1, pad=pad, wscale=wscale)
        )

        if fiber_map == 'id':
            # An identity map can only be used when the channel count is unchanged.
            if in_channels != out_channels:
                raise ValueError('fiber_map cannot be identity when channel dimension is changed.')
            self.fiber_map = F.identity
        elif fiber_map == 'zero_pad':
            raise NotImplementedError("fiber_map='zero_pad' is not implemented.")
        elif fiber_map == 'linear':
            # 1x1 projection with the same stride as conv1. A separate local is
            # used so the ``fiber_map`` argument is not rebound.
            fiber_link = conv_link(
                in_channels=in_channels, out_channels=out_channels, ksize=1, stride=stride, pad=0, wscale=wscale)
            self.add_link('fiber_map', fiber_link)
        else:
            # Report the offending option (previously the builtin ``type`` was
            # formatted by mistake, so the message never showed the bad value).
            raise ValueError('Unknown fiber_map: ' + str(fiber_map))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号