test_jit.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def test_assign_traces(self):
        """Check that output Variables are assigned traces before they are saved."""
        @traceable
        class MyFn(Function):
            @staticmethod
            def forward(ctx, a):
                out = a * 2
                ctx.save_for_backward(out)
                return out

            @staticmethod
            def backward(ctx, grad_a):
                a, = ctx.saved_variables
                return a * grad_a

        x = Variable(torch.randn(10, 10), requires_grad=True)
        trace, out = torch.jit.trace(MyFn.apply, x, nderivs=1)
        out.sum().backward()
        torch._C._jit_pass_dce(trace)
        self.assertExpected(str(trace))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号