test_autograd.py source code

python

Project: pytorch  Author: ezyang
# Excerpt from a unittest.TestCase in PyTorch's test/test_autograd.py; assumes
# `import warnings`, `import torch`, and `from torch.autograd import Variable`
# at module level.
def test_keepdim_warning(self):
        # Opt in to the backward-compatibility warning for reduction ops
        # called without an explicit keepdim argument.
        torch.utils.backcompat.keepdim_warning.enabled = True
        x = Variable(torch.randn(3, 4), requires_grad=True)

        def run_backward(y):
            y_ = y
            if type(y) is tuple:
                y_ = y[0]
            # check that backward runs smooth
            y_.backward(y_.data.new(y_.size()).normal_())

        def keepdim_check(f):
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
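                # Calling the reduction without keepdim should emit exactly one
                # UserWarning that mentions "keepdim".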
                y = f(x, 1)
                self.assertTrue(len(w) == 1)
                self.assertTrue(issubclass(w[-1].category, UserWarning))
                self.assertTrue("keepdim" in str(w[-1].message))
                run_backward(y)
                self.assertEqual(x.size(), x.grad.size())

                # check against explicit keepdim
                y2 = f(x, 1, keepdim=False)
                self.assertEqual(y, y2)
                run_backward(y2)

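                # keepdim=True retains the reduced dimension with size 1, so
                # squeezing it should reproduce the default (keepdim=False) result.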
                y3 = f(x, 1, keepdim=True)
                if type(y3) == tuple:
                    y3 = (y3[0].squeeze(1), y3[1].squeeze(1))
                else:
                    y3 = y3.squeeze(1)
                self.assertEqual(y, y3)
                run_backward(y3)

        keepdim_check(torch.sum)
        keepdim_check(torch.prod)
        keepdim_check(torch.mean)
        keepdim_check(torch.max)
        keepdim_check(torch.min)
        keepdim_check(torch.mode)
        keepdim_check(torch.median)
        keepdim_check(torch.kthvalue)
        keepdim_check(torch.var)
        keepdim_check(torch.std)
        torch.utils.backcompat.keepdim_warning.enabled = False
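
The equivalence the test asserts is easiest to see from tensor shapes. Below is a minimal sketch of the behavior under test, assuming a current PyTorch release in which Variable has been merged into Tensor and keepdim simply defaults to False (the warning machinery exercised above was transitional):

import torch

x = torch.randn(3, 4)
print(torch.sum(x, 1).shape)                # torch.Size([3])    -- dim 1 removed
print(torch.sum(x, 1, keepdim=True).shape)  # torch.Size([3, 1]) -- dim 1 kept
values, indices = torch.max(x, 1)           # tuple-returning reductions behave the same
print(values.shape)                         # torch.Size([3])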