pyramidal_residual_networks.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:pyramidal_residual_networks 作者: nutszebra 项目源码 文件源码
def __call__(self, x, train=False):
        h = self.conv1(x, train=train)
        for i in six.moves.range(len(self.strides)):
            for ii in six.moves.range(len(self.strides[i])):
                name = 'res_block{}_{}'.format(i, ii)
                h = self[name](h, train=train)
        batch, channels, height, width = h.data.shape
        h = F.reshape(F.average_pooling_2d(h, (height, width)), (batch, channels, 1, 1))
        return F.reshape(self.linear(h, train=train), (batch, self.category_num))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号