net.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:chainer-cifar 作者: dsanno 项目源码 文件源码
def __call__(self, x):
        """Apply the shake-shake residual block to ``x``.

        Two parallel branches (``conv*_1 -> bn*_1 -> activation1 ->
        conv*_2 -> bn*_2``) are merged with ``shake_shake`` and added to a
        shortcut of ``x``. If the branch output's spatial size differs from
        the input's, the shortcut is subsampled with a 1x1 average pooling
        using ``conv1_1``'s stride. If the branch output has more channels,
        the shortcut is zero-padded in front along axis 1 (``F.concat``'s
        default axis).

        Args:
            x: Input variable. It is unpacked as 4-D
                ``(batch, channel, height, width)``.

        Returns:
            ``self.activation2(h)`` when ``activation2`` is set, otherwise
            ``h``, where ``h = shake_shake(h1, h2) + shortcut``.

        Raises:
            ValueError: if the branch output has fewer channels than ``x``
                (zero-padding cannot shrink the shortcut).
        """
        h1 = self.bn1_1(self.conv1_1(x))
        h2 = self.bn2_1(self.conv2_1(x))
        if self.activation1 is not None:
            h1 = self.activation1(h1)
            h2 = self.activation1(h2)
        h1 = self.bn1_2(self.conv1_2(h1))
        h2 = self.bn2_2(self.conv2_2(h2))

        # Use the actual branch output shape as the target, rather than
        # re-deriving it from kernel/stride with a hard-coded padding of 1.
        shape_out = h1.data.shape
        shortcut = x
        if shortcut.data.shape[2:] != shape_out[2:]:
            # A 1x1 average pool with the conv stride samples every stride-th
            # pixel. For a 3x3 / pad-1 conv this yields the same output size.
            # Other configurations will surface as a shape error in the
            # addition below. Pooling before padding avoids wasting work on
            # the all-zero channels; the op is per-channel, so the result is
            # the same.
            shortcut = F.average_pooling_2d(shortcut, 1, self.conv1_1.stride)
        n, c, hh, ww = shortcut.data.shape
        pad_c = shape_out[1] - c
        if pad_c < 0:
            raise ValueError(
                'residual branch has fewer channels (%d) than the input (%d)'
                % (shape_out[1], c))
        if pad_c > 0:
            xp = chainer.cuda.get_array_module(shortcut.data)
            # Match the input dtype: F.concat requires equal dtypes, so a
            # hard-coded float32 would fail for e.g. float16 inputs.
            p = xp.zeros((n, pad_c, hh, ww), dtype=shortcut.data.dtype)
            shortcut = F.concat((p, shortcut))

        h = shake_shake(h1, h2) + shortcut
        if self.activation2 is not None:
            return self.activation2(h)
        return h
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号