test_nn.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def test_upsamplingNearest2d(self):
        m = nn.Upsample(size=4, mode='nearest')
        in_t = torch.ones(1, 1, 2, 2)
        out_t = m(Variable(in_t))
        self.assertEqual(torch.ones(1, 1, 4, 4), out_t.data)

        input = Variable(torch.randn(1, 1, 2, 2), requires_grad=True)
        self.assertEqual(
            F.upsample(input, 4, mode='nearest'),
            F.upsample(input, scale_factor=2, mode='nearest'))
        gradcheck(lambda x: F.upsample(x, 4, mode='nearest'), [input])
        gradgradcheck(lambda x: F.upsample(x, 4, mode='nearest'), [input])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号