test_nn.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def test_upsamplingLinear1d(self):
        m = nn.Upsample(size=4, mode='linear')
        in_t = torch.ones(1, 1, 2)
        out_t = m(Variable(in_t))
        self.assertEqual(torch.ones(1, 1, 4), out_t.data)

        input = Variable(torch.randn(1, 1, 2), requires_grad=True)
        gradcheck(lambda x: F.upsample(x, 4, mode='linear'), (input,))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号