Example source code for the Python Function() class

Source: test_autograd.py (project: pytorch, author: tylergenter)
def test_save_none_for_backward(self):
        test_case = self

        class MyFn(Function):

            def forward(self, input):
                self.save_for_backward(None, input, None)
                return input * input

            def backward(self, grad_output):
                n1, input, n2 = self.saved_tensors
                test_case.assertIsNone(n1)
                test_case.assertIsNone(n2)
                return 2 * input * grad_output

        x = Variable(torch.randn(5, 5), requires_grad=True)
        y = MyFn()(x)
        y.sum().backward()
        self.assertEqual(x.grad.data, 2 * x.data)
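
The test above uses the legacy instance-method Function API. A minimal sketch of the same behaviour with the staticmethod/ctx API used elsewhere on this page (the class name Square is illustrative, not from the test suite): None placeholders passed to save_for_backward come back as None from saved_tensors.

import torch
from torch.autograd import Function, Variable

class Square(Function):
    @staticmethod
    def forward(ctx, input):
        # None entries are allowed and come back unchanged from ctx.saved_tensors
        ctx.save_for_backward(None, input, None)
        return input * input

    @staticmethod
    def backward(ctx, grad_output):
        n1, input, n2 = ctx.saved_tensors
        assert n1 is None and n2 is None
        return 2 * input * grad_output

x = Variable(torch.randn(5, 5), requires_grad=True)
Square.apply(x).sum().backward()
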
Source: test_autograd.py (project: pytorch-coriander, author: hughperkins)
def test_mark_non_differentiable(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, input):
                output = input > 0
                ctx.mark_non_differentiable(output)
                return output

            @staticmethod
            def backward(ctx, grad_output):
                return (grad_output * 0).type(torch.DoubleTensor)

        x = Variable(torch.randn(5, 5), requires_grad=True)
        mask = MyFunction.apply(x)
        self.assertFalse(mask.requires_grad)
        y = x.masked_fill(mask, 0)
        y.sum().backward()
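
mark_non_differentiable tells autograd that an output carries no gradient: the output will not require grad, and whatever arrives for it in backward can be ignored. A rough sketch of a more typical use, a row-wise sort that also returns integer indices (the class name SortRows is hypothetical):

import torch
from torch.autograd import Function, Variable

class SortRows(Function):
    @staticmethod
    def forward(ctx, input):
        values, indices = input.sort(dim=1)
        ctx.mark_non_differentiable(indices)   # integer indices have no gradient
        ctx.save_for_backward(indices)
        return values, indices

    @staticmethod
    def backward(ctx, grad_values, grad_indices):   # grad_indices arrives as zeros; ignore it
        indices, = ctx.saved_tensors
        grad_input = torch.zeros_like(grad_values)
        # send each sorted-value gradient back to its original column
        grad_input.scatter_(1, indices, grad_values)
        return grad_input

x = Variable(torch.randn(3, 4), requires_grad=True)
values, indices = SortRows.apply(x)
values.sum().backward()
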
Source: test_autograd.py (project: pytorch-coriander, author: hughperkins)
def test_save_none_for_backward(self):
        test_case = self

        class MyFn(Function):

            def forward(self, input):
                self.save_for_backward(None, input, None)
                return input * input

            def backward(self, grad_output):
                n1, input, n2 = self.saved_tensors
                test_case.assertIsNone(n1)
                test_case.assertIsNone(n2)
                return 2 * input * grad_output

        x = Variable(torch.randn(5, 5), requires_grad=True)
        y = MyFn()(x)
        y.sum().backward()
        self.assertEqual(x.grad.data, 2 * x.data)
Source: test_jit.py (project: pytorch, author: ezyang)
def test_assign_traces(self):
        """Check that output Variables are assigned traces before they are saved."""
        @traceable
        class MyFn(Function):
            @staticmethod
            def forward(ctx, a):
                out = a * 2
                ctx.save_for_backward(out)
                return out

            @staticmethod
            def backward(ctx, grad_a):
                a, = ctx.saved_variables
                return a * grad_a

        x = Variable(torch.randn(10, 10), requires_grad=True)
        trace, out = torch.jit.trace(MyFn.apply, x, nderivs=1)
        out.sum().backward()
        torch._C._jit_pass_dce(trace)
        self.assertExpected(str(trace))
Source: test_autograd.py (project: pytorch, author: ezyang)
def test_mark_non_differentiable(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, input):
                output = input > 0
                ctx.mark_non_differentiable(output)
                return output

            @staticmethod
            def backward(ctx, grad_output):
                return (grad_output * 0).type(torch.DoubleTensor)

        x = Variable(torch.randn(5, 5), requires_grad=True)
        mask = MyFunction.apply(x)
        self.assertFalse(mask.requires_grad)
        y = x.masked_fill(mask, 0)
        y.sum().backward()
Source: test_autograd.py (project: pytorch, author: ezyang)
def test_backward_device(self):
        # check that current device matches the variable's device
        device = [None]

        class Identity(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, grad_output):
                device[0] = torch.cuda.current_device()
                return grad_output.clone()

        v = Variable(torch.randn(1).cuda(1), requires_grad=True)
        Identity.apply(v).backward()
        self.assertEqual(device[0], 1)
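
The test above needs at least two CUDA devices. A related property that can be checked on a single GPU, as a rough sketch: the gradient is allocated on the same device as the variable it belongs to.

import torch
from torch.autograd import Variable

if torch.cuda.is_available():
    v = Variable(torch.randn(4).cuda(), requires_grad=True)
    (v * 2).sum().backward()
    # the gradient lives on the same device as the variable
    assert v.grad.data.get_device() == v.data.get_device()
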
Source: test_autograd.py (project: pytorch, author: ezyang)
def test_reentrant(self):
        y_data = torch.randn(2, 2)

        class Reenter(Function):
            @staticmethod
            def forward(ctx, x_data):
                ctx.x = Variable(x_data, requires_grad=True)
                ctx.y = Variable(y_data, requires_grad=True)
                ctx.output_var = ctx.x * ctx.y
                return ctx.output_var.data

            @staticmethod
            def backward(ctx, grad_output):
                ctx.output_var.sum().backward()
                return ctx.x.grad * grad_output

        x = Variable(torch.randn(2, 2), requires_grad=True)
        out = Reenter.apply(x)
        out.sum().backward()
        self.assertEqual(x.grad.data, y_data)
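
Re-entering the autograd engine from inside backward, as above, is the same mechanism that gradient checkpointing relies on: the forward pass of a block is re-run during backward and differentiated on the spot. A rough sketch using torch.utils.checkpoint, which ships in later PyTorch releases than the one these tests target:

import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    return torch.tanh(x) * x

x = torch.randn(2, 2, requires_grad=True)
y = checkpoint(block, x)   # block's forward is re-run inside backward, reentrantly
y.sum().backward()
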
Source: test_onnx.py (project: pytorch, author: ezyang)
def test_symbolic_mismatch(self):
        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                # The inside of this function should never be invoked, because
                # we will fail due to an argument mismatch first.
                assert False

            @staticmethod
            def forward(ctx, x, y):
                return x + y

        x = Variable(torch.randn(2, 2).fill_(1.0))
        y = Variable(torch.randn(2, 2).fill_(1.0))
        with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
            export_to_string(FuncModule(MyFun().apply), (x, y))

    # TODO: Do an nn style test for these
Source: test_jit.py (project: pytorch, author: pytorch)
def test_assign_traces(self):
        """Check that output Variables are assigned traces before they are saved."""
        @traceable
        class MyFn(Function):
            @staticmethod
            def forward(ctx, a):
                out = a * 2
                ctx.save_for_backward(out)
                return out

            @staticmethod
            def backward(ctx, grad_a):
                a, = ctx.saved_variables
                return a * grad_a

        x = Variable(torch.randn(10, 10), requires_grad=True)
        trace, out = torch.jit.trace(MyFn.apply, x, nderivs=1)
        out.sum().backward()
        torch._C._jit_pass_dce(trace)
        self.assertExpectedTrace(trace)
Source: test_jit.py (project: pytorch, author: pytorch)
def test_inplace_check(self):
        class MyInplaceFn(Function):
            @staticmethod
            def forward(self, x):
                x.add_(1)
                self.mark_dirty(x)
                return x

            @staticmethod
            def backward(self, grad):
                return grad

        @torch.jit.compile(nderivs=0)
        def fn(x):
            return MyInplaceFn.apply(x)
        x = Variable(torch.randn(5, 5))
        fn(x)  # trace
        with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
            fn(x)
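
Outside of tracing, an in-place Function is legal as long as it declares the mutation with mark_dirty. A minimal sketch (the class name AddOneInplace is hypothetical); note that in-place ops are not allowed directly on a leaf Variable that requires grad, hence the clone:

import torch
from torch.autograd import Function, Variable

class AddOneInplace(Function):
    @staticmethod
    def forward(ctx, x):
        x.add_(1)
        ctx.mark_dirty(x)   # declare that the input was modified in place
        return x

    @staticmethod
    def backward(ctx, grad):
        return grad

x = Variable(torch.randn(5, 5), requires_grad=True)
y = AddOneInplace.apply(x.clone())   # clone first: leaves requiring grad must not be mutated in place
y.sum().backward()
assert (x.grad.data == 1).all()
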
Source: test_autograd.py (project: pytorch, author: pytorch)
def test_function_returns_input(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, grad):
                return grad * 2

        v = Variable(torch.ones(1), requires_grad=True)
        MyFunction.apply(v).backward()
        self.assertEqual(v.grad.data.tolist(), [2])

        v.grad.data.zero_()
        MyFunction.apply(v.clone()).backward()
        self.assertEqual(v.grad.data.tolist(), [2])
Source: test_autograd.py (project: pytorch, author: pytorch)
def test_mark_non_differentiable_mixed(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, input):
                a = input + 1
                b = input + 2
                ctx.mark_non_differentiable(a)
                return a, b

            @staticmethod
            def backward(ctx, grad_a, grad_b):
                self.assertTrue((grad_a == 0).all())
                self.assertTrue((grad_b == 1).all())
                return grad_b

        x = Variable(torch.randn(5, 5), requires_grad=True)
        a, b = MyFunction.apply(x)
        self.assertFalse(a.requires_grad)
        self.assertTrue(b.requires_grad)
        b.sum().backward()
        self.assertEqual(x.grad.data, torch.ones(5, 5))
Source: test_autograd.py (project: pytorch, author: pytorch)
def test_mark_non_differentiable_none(self):
        # This used to segfault because MyFunction would send back null
        # gradients to MulBackward, which is implemented in C++. C++
        # implemented functions expect incoming grad_outputs to be non-null.
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, input):
                output = input.clone()
                ctx.mark_non_differentiable(output)
                return output

            @staticmethod
            def backward(ctx, grad_output):
                return None

        x = Variable(torch.randn(5, 5), requires_grad=True)
        r = MyFunction.apply(x * x)
        (r * x).sum().backward()
Source: test_autograd.py (project: pytorch, author: pytorch)
def test_backward_device(self):
        # check that current device matches the variable's device
        device = [None]

        class Identity(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, grad_output):
                device[0] = torch.cuda.current_device()
                return grad_output.clone()

        v = Variable(torch.randn(1).cuda(1), requires_grad=True)
        Identity.apply(v).backward()
        self.assertEqual(device[0], 1)
Source: test_autograd.py (project: pytorch, author: pytorch)
def test_reentrant(self):
        y_data = torch.randn(2, 2)

        class Reenter(Function):
            @staticmethod
            def forward(ctx, x_data):
                ctx.x = Variable(x_data, requires_grad=True)
                ctx.y = Variable(y_data, requires_grad=True)
                ctx.output_var = ctx.x * ctx.y
                return ctx.output_var.data

            @staticmethod
            def backward(ctx, grad_output):
                ctx.output_var.sum().backward()
                return ctx.x.grad * grad_output

        x = Variable(torch.randn(2, 2), requires_grad=True)
        out = Reenter.apply(x)
        out.sum().backward(create_graph=True)
        self.assertEqual(x.grad.data, y_data)
Source: test_autograd.py (project: pytorch, author: pytorch)
def test_inplace_view_python(self):
        # in-place modifications of Python-autograd created view
        a = Variable(torch.randn(4, 4), requires_grad=True)
        b = Variable(torch.randn(2, 2), requires_grad=True)

        class PyAdd(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.mark_dirty(x)
                x.add_(y)
                return x

            @staticmethod
            def backward(ctx, grad):
                return grad, grad

        def func(root, b):
            x = root.clone()
            PyAdd.apply(x.narrow(1, 2, 2).narrow(0, 1, 2), b)
            PyAdd.apply(x.narrow(1, 0, 2).narrow(0, 1, 2), b)
            return x

        gradcheck(func, [a, b], raise_exception=True)
        go = Variable(torch.randn(a.size()), requires_grad=True)
        gradgradcheck(func, (a, b), (go,))
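
gradcheck and gradgradcheck, used above, compare analytical gradients against finite differences and are equally handy for validating hand-written Functions. A minimal standalone sketch; double precision keeps the numerical comparison stable:

import torch
from torch.autograd import Variable, gradcheck

inputs = (Variable(torch.randn(4, 4).double(), requires_grad=True),)
assert gradcheck(torch.tanh, inputs, eps=1e-6, atol=1e-4)
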
Source: test_autograd.py (project: pytorch, author: tylergenter)
def test_function(self):
        class MyFunction(Function):

            @staticmethod
            def forward(ctx, tensor1, scalar, tensor2):
                ctx.scalar = scalar
                ctx.save_for_backward(tensor1, tensor2)
                return tensor1 + scalar * tensor2 + tensor1 * tensor2

            @staticmethod
            def backward(ctx, grad_output):
                var1, var2 = ctx.saved_variables
                # NOTE: self is the test case here
                self.assertIsInstance(var1, Variable)
                self.assertIsInstance(var2, Variable)
                self.assertIsInstance(grad_output, Variable)
                return (grad_output + grad_output * var2, None,
                        grad_output * ctx.scalar + grad_output * var1)

        x, y = self._function_test(MyFunction)

        x_grad_desc = graph_desc(x.grad.grad_fn)
        y_grad_desc = graph_desc(y.grad.grad_fn)
        self.assertEqual(
            x_grad_desc,
            'Identity(AddBackward(ExpandBackward(AccumulateGrad()), '
            'MulBackward(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')
        self.assertEqual(
            y_grad_desc,
            'Identity(AddBackward(MulConstantBackward(ExpandBackward(AccumulateGrad())), '
            'MulBackward(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')
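
The pattern above is the general template for a staticmethod Function: non-tensor arguments are stashed directly on ctx, tensors go through save_for_backward, and backward returns exactly one gradient (or None) per forward input. A self-contained sketch of the same pattern (the class name ScaledAddMul is hypothetical; recent PyTorch exposes the saved tensors as ctx.saved_tensors, while the older name saved_variables appears in the tests on this page):

import torch
from torch.autograd import Function, Variable

class ScaledAddMul(Function):
    @staticmethod
    def forward(ctx, tensor1, scalar, tensor2):
        ctx.scalar = scalar                       # non-tensor state lives on ctx
        ctx.save_for_backward(tensor1, tensor2)   # tensors go through save_for_backward
        return tensor1 + scalar * tensor2 + tensor1 * tensor2

    @staticmethod
    def backward(ctx, grad_output):
        t1, t2 = ctx.saved_tensors
        # one gradient per forward input; None for the non-differentiable scalar
        return (grad_output + grad_output * t2, None,
                grad_output * ctx.scalar + grad_output * t1)

x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=True)
ScaledAddMul.apply(x, 2, y).sum().backward()
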
Source: test_autograd.py (project: pytorch, author: tylergenter)
def test_once_differentiable(self):
        class MyFunction(Function):

            @staticmethod
            def forward(ctx, tensor1, scalar, tensor2):
                ctx.scalar = scalar
                ctx.save_for_backward(tensor1, tensor2)
                return tensor1 + scalar * tensor2 + tensor1 * tensor2

            @staticmethod
            @once_differentiable
            def backward(ctx, grad_output):
                t1, t2 = ctx.saved_tensors
                # NOTE: self is the test case here
                self.assertTrue(torch.is_tensor(t1))
                self.assertTrue(torch.is_tensor(t2))
                self.assertTrue(torch.is_tensor(grad_output))
                return (grad_output + grad_output * t2, None,
                        grad_output * ctx.scalar + grad_output * t1)

        x, y = self._function_test(MyFunction)
        x_grad_desc = graph_desc(x.grad.grad_fn)
        y_grad_desc = graph_desc(y.grad.grad_fn)
        self.assertEqual(graph_desc(x.grad.grad_fn),
                         'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
        self.assertEqual(graph_desc(y.grad.grad_fn),
                         'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
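
once_differentiable runs the decorated backward with differentiation disabled, so the gradients it produces cannot be differentiated a second time; that is why the double-backward graphs above collapse to Error nodes. A short sketch of decorating a backward that is only meant to be differentiated once (the class name ClipGrad is hypothetical):

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class ClipGrad(Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        # grad_output arrives as a plain, non-differentiable tensor here
        return grad_output.clamp(-1, 1)
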
Source: test_autograd.py (project: pytorch, author: tylergenter)
def test_hook_none(self):
        # WARNING: this is a test for autograd internals.
        # You should never have to use such things in your code.
        class NoneGradientFunction(Function):

            def forward(self, x, y):
                assert self.needs_input_grad[0]
                assert not self.needs_input_grad[1]
                return x, y

            def backward(self, grad_x, grad_y):
                return grad_x, None

        fn = NoneGradientFunction()
        was_called = [False]

        def hook(grad_input, grad_output):
            self.assertIsInstance(grad_input, tuple)
            self.assertIsInstance(grad_output, tuple)
            self.assertIsNotNone(grad_input[0])
            self.assertIsNone(grad_input[1])
            self.assertIsNotNone(grad_output[0])
            self.assertIsNotNone(grad_output[1])
            was_called[0] = True
        fn.register_hook(hook)

        x = Variable(torch.randn(5, 5), requires_grad=True)
        y = Variable(torch.randn(5, 5))
        sum(fn(x, y)).sum().backward()
        self.assertTrue(was_called[0])
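
Function.register_hook, as the comment warns, is an autograd internal. The public way to observe or modify a gradient is Variable.register_hook; a short sketch:

import torch
from torch.autograd import Variable

x = Variable(torch.randn(5, 5), requires_grad=True)
y = (x * 2).sum()
x.register_hook(lambda grad: print(grad))   # called with x's gradient during backward
y.backward()
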
Source: test_autograd.py (project: pytorch, author: tylergenter)
def test_gc_in_destructor(self):
        """
        Previously, if a Function destructor triggered a garbage collection,
        the Variable's tp_dealloc handler would get called twice leading to a
        segfault.
        """
        class CollectOnDelete(Function):

            def __del__(self):
                gc.collect()

        for i in range(10):
            Variable(torch.randn(10, 10), _grad_fn=CollectOnDelete())
Source: test_autograd.py (project: pytorch, author: tylergenter)
def test_too_many_grads(self):
        class MyFn(Function):

            def forward(self, input):
                return input

            def backward(self, grad_output):
                return grad_output, None, None

        x = Variable(torch.randn(5, 5), requires_grad=True)
        y = MyFn()(x)
        y.sum().backward()
        self.assertEqual(x.grad.data, x.data.clone().fill_(1))
Source: test_autograd.py (project: pytorch, author: tylergenter)
def test_dep_nograd(self):
        class F1(Function):

            def forward(self, input):
                out = torch.randn(input.size())
                self.mark_non_differentiable(out)
                return input, out

            def backward(self, grad_output, ignored):
                return grad_output

        class F2(Function):

            def forward(self, input, ignored):
                return input

            def backward(self, grad_output):
                return grad_output, None

        x = Variable(torch.randn(5), requires_grad=True)
        a, b = F1()(x)
        b = b + 1  # separate F1 from F2 by another op
        self.assertTrue(a.requires_grad)
        self.assertFalse(b.requires_grad)
        c = F2()(a, b)
        c.backward(torch.ones(c.size()))
        self.assertEqual(x.grad.data, torch.ones(x.size()))
Source: test_autograd.py (project: pytorch-coriander, author: hughperkins)
def test_function(self):
        class MyFunction(Function):

            @staticmethod
            def forward(ctx, tensor1, scalar, tensor2):
                ctx.scalar = scalar
                ctx.save_for_backward(tensor1, tensor2)
                return tensor1 + scalar * tensor2 + tensor1 * tensor2

            @staticmethod
            def backward(ctx, grad_output):
                var1, var2 = ctx.saved_variables
                # NOTE: self is the test case here
                self.assertIsInstance(var1, Variable)
                self.assertIsInstance(var2, Variable)
                self.assertIsInstance(grad_output, Variable)
                return (grad_output + grad_output * var2, None,
                        grad_output * ctx.scalar + grad_output * var1)

        x, y = self._function_test(MyFunction)

        x_grad_desc = graph_desc(x.grad.grad_fn)
        y_grad_desc = graph_desc(y.grad.grad_fn)
        self.assertEqual(
            x_grad_desc,
            'Identity(AddBackward(ExpandBackward(AccumulateGrad()), '
            'MulBackward(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')
        self.assertEqual(
            y_grad_desc,
            'Identity(AddBackward(MulConstantBackward(ExpandBackward(AccumulateGrad())), '
            'MulBackward(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')
Source: test_autograd.py (project: pytorch-coriander, author: hughperkins)
def test_once_differentiable(self):
        class MyFunction(Function):

            @staticmethod
            def forward(ctx, tensor1, scalar, tensor2):
                ctx.scalar = scalar
                ctx.save_for_backward(tensor1, tensor2)
                return tensor1 + scalar * tensor2 + tensor1 * tensor2

            @staticmethod
            @once_differentiable
            def backward(ctx, grad_output):
                t1, t2 = ctx.saved_tensors
                # NOTE: self is the test case here
                self.assertTrue(torch.is_tensor(t1))
                self.assertTrue(torch.is_tensor(t2))
                self.assertTrue(torch.is_tensor(grad_output))
                return (grad_output + grad_output * t2, None,
                        grad_output * ctx.scalar + grad_output * t1)

        x, y = self._function_test(MyFunction)
        x_grad_desc = graph_desc(x.grad.grad_fn)
        y_grad_desc = graph_desc(y.grad.grad_fn)
        self.assertEqual(graph_desc(x.grad.grad_fn),
                         'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
        self.assertEqual(graph_desc(y.grad.grad_fn),
                         'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
Source: test_autograd.py (project: pytorch-coriander, author: hughperkins)
def test_hook_none(self):
        # WARNING: this is a test for autograd internals.
        # You should never have to use such things in your code.
        class NoneGradientFunction(Function):

            def forward(self, x, y):
                assert self.needs_input_grad[0]
                assert not self.needs_input_grad[1]
                return x, y

            def backward(self, grad_x, grad_y):
                return grad_x, None

        fn = NoneGradientFunction()
        was_called = [False]

        def hook(grad_input, grad_output):
            self.assertIsInstance(grad_input, tuple)
            self.assertIsInstance(grad_output, tuple)
            self.assertIsNotNone(grad_input[0])
            self.assertIsNone(grad_input[1])
            self.assertIsNotNone(grad_output[0])
            self.assertIsNotNone(grad_output[1])
            was_called[0] = True
        fn.register_hook(hook)

        x = Variable(torch.randn(5, 5), requires_grad=True)
        y = Variable(torch.randn(5, 5))
        sum(fn(x, y)).sum().backward()
        self.assertTrue(was_called[0])
Source: test_autograd.py (project: pytorch-coriander, author: hughperkins)
def test_return_leaf(self):
        class Identity(Function):

            def forward(self, a, b):
                return a, a + b

            def backward(self, grad_a, grad_b):
                return grad_a + grad_b, grad_b

        class Inplace(InplaceFunction):

            def forward(self, a, b):
                self.mark_dirty(a)
                return a.add_(b), b + 2

            def backward(self, grad_a, grad_b):
                return grad_a, grad_a + grad_b

        x = Variable(torch.randn(5, 5), requires_grad=True)
        y = Variable(torch.randn(5, 5), requires_grad=True)

        q, p = Identity()(x, y)
        # Make sure hooks only receive grad from usage of q, not x.
        q.register_hook(
            lambda grad: self.assertEqual(grad.data, torch.ones(5, 5)))
        (q + p + x).sum().backward()
        self.assertEqual(x.grad.data, torch.ones(5, 5) * 3)
        self.assertEqual(y.grad.data, torch.ones(5, 5))
        del q, p  # these need to be freed, or next part will raise an error
Source: test_autograd.py (project: pytorch-coriander, author: hughperkins)
def test_too_many_grads(self):
        class MyFn(Function):

            def forward(self, input):
                return input

            def backward(self, grad_output):
                return grad_output, None, None

        x = Variable(torch.randn(5, 5), requires_grad=True)
        y = MyFn()(x)
        y.sum().backward()
        self.assertEqual(x.grad.data, x.data.clone().fill_(1))
Source: test_autograd.py (project: pytorch-coriander, author: hughperkins)
def test_dep_nograd(self):
        class F1(Function):

            def forward(self, input):
                out = torch.randn(input.size())
                self.mark_non_differentiable(out)
                return input, out

            def backward(self, grad_output, ignored):
                return grad_output

        class F2(Function):

            def forward(self, input, ignored):
                return input

            def backward(self, grad_output):
                return grad_output, None

        x = Variable(torch.randn(5), requires_grad=True)
        a, b = F1()(x)
        b = b + 1  # separate F1 from F2 by another op
        self.assertTrue(a.requires_grad)
        self.assertFalse(b.requires_grad)
        c = F2()(a, b)
        c.backward(torch.ones(c.size()))
        self.assertEqual(x.grad.data, torch.ones(x.size()))
Source: test_jit.py (project: pytorch, author: ezyang)
def test_legacy_fail(self):

        class MyLegacyFn(Function):
            def forward(self, x):
                return x

            def backward(self, grad_output):
                return grad_output

        x = Variable(torch.Tensor([0]), requires_grad=True)
        trace = torch._C._tracer_enter((x,), 0)
        self.assertRaisesRegex(RuntimeError, "MyLegacyFn", lambda: MyLegacyFn()(x))
        torch._C._tracer_exit((x,))
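
The tracer rejects legacy instance-method Functions such as MyLegacyFn; rewriting them with staticmethod forward/backward and calling them through Function.apply makes them acceptable. A sketch of the rewritten version (the class name MyStaticFn is hypothetical):

import torch
from torch.autograd import Function, Variable

class MyStaticFn(Function):   # static-method rewrite of MyLegacyFn above
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = Variable(torch.Tensor([0]), requires_grad=True)
MyStaticFn.apply(x).backward()
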
Source: test_autograd.py (project: pytorch, author: ezyang)
def test_function(self):
        class MyFunction(Function):

            @staticmethod
            def forward(ctx, tensor1, scalar, tensor2):
                ctx.scalar = scalar
                ctx.save_for_backward(tensor1, tensor2)
                return tensor1 + scalar * tensor2 + tensor1 * tensor2

            @staticmethod
            def backward(ctx, grad_output):
                var1, var2 = ctx.saved_variables
                # NOTE: self is the test case here
                self.assertIsInstance(var1, Variable)
                self.assertIsInstance(var2, Variable)
                self.assertIsInstance(grad_output, Variable)
                return (grad_output + grad_output * var2, None,
                        grad_output * ctx.scalar + grad_output * var1)

        x, y = self._function_test(MyFunction)

        x_grad_desc = graph_desc(x.grad.grad_fn)
        y_grad_desc = graph_desc(y.grad.grad_fn)
        self.assertEqual(
            x_grad_desc,
            'Identity(AddBackward(ExpandBackward(AccumulateGrad()), '
            'MulBackward(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')
        self.assertEqual(
            y_grad_desc,
            'Identity(AddBackward(MulConstantBackward(ExpandBackward(AccumulateGrad())), '
            'MulBackward(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')

