test_autograd.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:pytorch 作者: tylergenter 项目源码 文件源码
def test_once_differentiable(self):
        """Exercise a custom ``Function`` whose backward is ``once_differentiable``.

        ``MyFunction.backward`` asserts that the saved inputs and the incoming
        gradient all satisfy ``torch.is_tensor``. After running the function
        through ``self._function_test``, the graph recorded on each returned
        input's ``.grad`` is compared against the expected description string.

        NOTE(review): ``self._function_test`` and ``graph_desc`` are defined
        outside this block; presumably the former runs forward + backward and
        returns the two tensor inputs -- confirm against the helper.
        """
        class MyFunction(Function):

            @staticmethod
            def forward(ctx, tensor1, scalar, tensor2):
                # Keep everything backward needs: the plain scalar lives on
                # ctx directly, the tensors go through save_for_backward.
                ctx.scalar = scalar
                ctx.save_for_backward(tensor1, tensor2)
                return tensor1 + scalar * tensor2 + tensor1 * tensor2

            @staticmethod
            @once_differentiable
            def backward(ctx, grad_output):
                t1, t2 = ctx.saved_tensors
                # NOTE: self is the test case here
                self.assertTrue(torch.is_tensor(t1))
                self.assertTrue(torch.is_tensor(t2))
                self.assertTrue(torch.is_tensor(grad_output))
                # One gradient per forward input:
                #   d/dtensor1 = 1 + tensor2
                #   scalar     -> None (no gradient returned for it)
                #   d/dtensor2 = scalar + tensor1
                return (grad_output + grad_output * t2, None,
                        grad_output * ctx.scalar + grad_output * t1)

        x, y = self._function_test(MyFunction)
        # Both gradients are expected to carry the same graph description, so
        # spell the expectation once and describe each graph exactly once.
        expected_desc = 'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))'
        self.assertEqual(graph_desc(x.grad.grad_fn), expected_desc)
        self.assertEqual(graph_desc(y.grad.grad_fn), expected_desc)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号