net.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目: chainer-cifar 作者: dsanno 项目源码 文件源码
def __call__(self, x):
        """Apply the residual block to ``x``.

        The residual branch is ``bn2(conv2(activation1(bn1(conv1(x)))))``.
        It is added to a shortcut built from ``x``, and ``activation2`` is
        applied to the sum if it is set.

        If the block changes the tensor shape, the shortcut is adapted:
        - it is spatially subsampled with conv1's stride (1x1 average
          pooling), and
        - zero channels are prepended until the channel count matches.

        Stochastic depth: in training mode, with probability
        ``self.skip_ratio`` the residual branch is skipped and the shortcut
        is returned as-is. Note that ``activation2`` is NOT applied on that
        path. Outside training mode, the residual branch is scaled by
        ``1 - skip_ratio`` instead.

        If ``self.swapout`` is set, both branches go through ``F.dropout``
        (default ratio) before being summed.

        Args:
            x: Input variable. Its ``.data`` is indexed as
                (batch, channels, height, width).

        Returns:
            The block output variable.

        Raises:
            ValueError: If conv1 has fewer output channels than ``x`` has
                input channels. The zero-padding shortcut cannot shrink
                channels.
        """
        # Decide up front whether to drop the residual branch
        # (training only; the RNG is consulted only when dropping is possible).
        skip = False
        if chainer.config.train and self.skip_ratio > 0 and np.random.rand() < self.skip_ratio:
            skip = True

        sh, sw = self.conv1.stride
        c_out, _, kh, kw = self.conv1.W.data.shape
        b, c, hh, ww = x.data.shape
        if sh == 1 and sw == 1:
            shape_out = (b, c_out, hh, ww)
        else:
            # The "+ 2" assumes conv1 was built with pad=1.
            # NOTE(review): confirm against the block's constructor.
            shape_out = (b, c_out,
                         (hh + 2 - kh) // sh + 1,
                         (ww + 2 - kw) // sw + 1)

        # The residual branch always starts from the unmodified input;
        # only the shortcut ``x`` is reshaped below.
        h = x
        if x.data.shape != shape_out:
            xp = chainer.cuda.get_array_module(x.data)
            if x.data.shape[2:] != shape_out[2:]:
                # A 1x1 average pool with conv1's stride is a plain subsample.
                # Doing it before padding avoids pooling the zero channels.
                x = F.average_pooling_2d(x, 1, (sh, sw))
            n, c, hh, ww = x.data.shape
            pad_c = shape_out[1] - c
            if pad_c < 0:
                raise ValueError(
                    'cannot shrink shortcut from {} to {} channels'.format(
                        c, shape_out[1]))
            if pad_c > 0:
                # Match the input dtype (e.g. float16) so concat stays
                # homogeneous. A raw array is passed because the padding
                # needs no gradient.
                p = xp.zeros((n, pad_c, hh, ww), dtype=x.data.dtype)
                x = F.concat((p, x))
        if skip:
            return x

        h = self.bn1(self.conv1(h))
        if self.activation1 is not None:
            h = self.activation1(h)
        h = self.bn2(self.conv2(h))
        if not chainer.config.train:
            # During training the branch survives with probability
            # (1 - skip_ratio); scale it by that at inference.
            h = h * (1 - self.skip_ratio)
        if self.swapout:
            h = F.dropout(h) + F.dropout(x)
        else:
            h = h + x
        if self.activation2 is not None:
            return self.activation2(h)
        return h
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号