net.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:chainer-cifar 作者: dsanno 项目源码 文件源码
def __init__(self, depth=18, alpha=16, start_channel=16, skip=False):
        """Build the PyramidNet layer stack and register its links.

        The network is a stem convolution (``bconv1``), three stages of
        PyramidBlocks, and a BatchNormalization -> ReLU -> average-pool ->
        Linear head. Stage 1 has ``depth`` blocks; stages 2 and 3 each open
        with one stride-2 block followed by ``depth - 1`` blocks, giving
        ``3 * depth`` PyramidBlocks in total.

        Args:
            depth: Number of PyramidBlocks in the first stage (must be >= 1).
            alpha: Width growth control. Every PyramidBlock widens the
                running channel count by ``alpha / depth`` (accumulated as a
                float, rounded per block), so the final width is about
                ``start_channel + 3 * alpha``.
                NOTE(review): the PyramidNet paper spreads alpha over all
                blocks (final width start + alpha); confirm this per-stage
                convention is intended before changing it.
            start_channel: Output channels of the stem conv ``bconv1``,
                which takes 3 input channels.
            skip: If True, the non-stride-2 blocks get a ``skip_ratio`` that
                rises linearly from 0 (first such block) to 0.5 (last one);
                with a single such block (depth == 1) it is 0. If False,
                every such block gets ``skip_ratio=0``. Stride-2 blocks are
                constructed without a ``skip_ratio`` argument.

        Raises:
            ValueError: If ``depth`` is less than 1.
        """
        if depth < 1:
            raise ValueError('depth must be >= 1, got {}'.format(depth))
        super(PyramidNet, self).__init__()
        channel_diff = float(alpha) / depth
        # Non-stride-2 blocks are numbered 0..skip_size across all stages;
        # their skip_ratio grows linearly with that index.
        skip_size = depth * 3 - 3

        def skip_ratio_for(index):
            # skip_size is 0 when depth == 1; return 0 for the lone block
            # instead of raising ZeroDivisionError on 0 / 0.
            if not skip or skip_size == 0:
                return 0
            return float(index) / skip_size * 0.5

        # Block plan as (stride, skip index) pairs: stage 1 is `depth`
        # stride-1 blocks; stages 2 and 3 each start with a stride-2 block
        # (no skip index) followed by `depth - 1` stride-1 blocks.
        plan = [(1, i) for i in six.moves.range(depth)]
        next_index = depth
        for _stage in six.moves.range(2):
            plan.append((2, None))
            for _ in six.moves.range(depth - 1):
                plan.append((1, next_index))
                next_index += 1

        channel = start_channel
        links = [('bconv1', BatchConv2D(3, channel, 3, 1, 1))]
        for stride, skip_index in plan:
            in_channel = channel
            # Accumulate as a float and round per block so the widths
            # follow a smooth linear ramp rather than compounding rounding.
            channel += channel_diff
            in_size = int(round(in_channel))
            out_size = int(round(channel))
            if stride == 2:
                block = PyramidBlock(in_size, out_size, stride=2)
            else:
                block = PyramidBlock(in_size, out_size,
                                     skip_ratio=skip_ratio_for(skip_index))
            links.append(('py{}'.format(len(links)), block))

        final_channel = int(round(channel))
        links.append(('bn{}'.format(len(links)),
                      L.BatchNormalization(final_channel)))
        links.append(('_relu{}'.format(len(links)), F.ReLU()))
        # Pool window is hard-coded to 8; presumably this expects 8x8 maps
        # here (e.g. 32x32 inputs after two stride-2 stages) -- confirm if
        # the input resolution changes.
        links.append(('_apool{}'.format(len(links)),
                      F.AveragePooling2D(8, 1, 0, False)))
        links.append(('fc{}'.format(len(links)), L.Linear(final_channel, 10)))

        # Entries whose names start with '_' are chainer function objects
        # (F.ReLU, F.AveragePooling2D), not links, so they are not registered.
        for name, f in links:
            if not name.startswith('_'):
                self.add_link(name, f)
        # Keep the full ordered stack, including the unregistered functions.
        self.layers = links
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号