def _function_test(self, cls):
    """Apply a custom autograd function and check the gradients it produces.

    ``cls.apply`` is called as ``cls.apply(x, 2, y)`` on two random 5x5
    leaf inputs. The result is summed and back-propagated, and the
    accumulated gradients are compared against::

        x.grad == y + 1
        y.grad == x + 2

    (consistent with, e.g., a forward of ``x + 2 * y + x * y``). The backward
    pass is run with ``create_graph=True``, so the gradients are also
    required to carry a ``grad_fn``.

    Args:
        cls: presumably an autograd ``Function`` subclass; only its
            ``apply`` method is used here.

    Returns:
        The leaf inputs ``(x, y)`` with ``.grad`` populated, presumably so
        callers can run further (e.g. higher-order) checks on them.
    """
    x = Variable(torch.randn(5, 5), requires_grad=True)
    y = Variable(torch.randn(5, 5), requires_grad=True)
    result = cls.apply(x, 2, y)

    total = result.sum()
    # Give the output gradient the same shape as ``total`` rather than a
    # hard-coded (1,), so this works whether ``sum()`` yields a 1-element
    # or a 0-dim value.
    go = Variable(torch.ones(total.size()), requires_grad=True)
    # The grad_fn assertions below need a graph of the derivative; request
    # it explicitly instead of relying on the version-dependent default.
    total.backward(go, create_graph=True)

    self.assertEqual(x.grad.data, y.data + torch.ones(5, 5))
    self.assertEqual(y.grad.data, x.data + torch.ones(5, 5) * 2)
    # NOTE(review): ``volatile`` is a legacy Variable attribute; confirm
    # these checks are still meaningful for the torch version in use.
    self.assertFalse(x.grad.volatile)
    self.assertFalse(y.grad.volatile)
    # Only holds because backward() above was asked to build a graph.
    self.assertIsNotNone(x.grad.grad_fn)
    self.assertIsNotNone(y.grad.grad_fn)
    return x, y
评论列表
文章目录