test_autograd.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def test_basic_op_grad_fallback(self):
        """Exercise the grad fallback path for basic elementwise ops whose
        operands have the same number of elements but different shapes.

        NOTE(review): this relies on legacy (pre-broadcasting) PyTorch
        semantics where same-numel / different-shape operands were accepted
        with a "broadcastable" warning. The grad_output produced for the
        second argument must be reshaped to that argument's shape before it
        flows into its own graph — here a ``.mm()`` node, which requires a
        grad_output of exactly the right size. On modern PyTorch these
        shape combinations raise instead of warning.
        """
        # x has 4*6 = 24 elements; all second operands below also have 24
        # elements but different shapes, forcing the reshape fallback.
        x = Variable(torch.randn(4, 6), requires_grad=True)
        # +1e-2 keeps entries strictly positive (safe as divisor and as a
        # base/exponent in the `/`, `addcdiv` and `**` cases below).
        b = Variable(torch.rand(12, 1) + 1e-2, requires_grad=True)
        c = Variable(torch.rand(8, 1) + 1e-2, requires_grad=True)

        def y():
            # .mm() depends on the grad_output being of correct size
            # (12,1) @ (1,2) -> (12,2): 24 elements, shape != x's (4,6).
            return b.mm(Variable(torch.rand(1, 2) + 1e-2))

        def z():
            # (8,1) @ (1,3) -> (8,3): also 24 elements, third distinct shape.
            return c.mm(Variable(torch.rand(1, 3) + 1e-2))

        # suppress broadcastable warning
        # Each line builds a fresh graph (y()/z() draw new random values)
        # and backprops a scalar; the point is that none of these raise,
        # i.e. the reshaped grad_output reaches the .mm() nodes intact.
        with warnings.catch_warnings(record=True):
            (x + y()).sum().backward()
            (x - y()).sum().backward()
            (x * y()).sum().backward()
            (x / y()).sum().backward()
            (x.dist(y())).sum().backward()
            (x.lerp(y(), 0.5)).sum().backward()
            (x.max(y())).sum().backward()
            (x.min(y())).sum().backward()
            (x.masked_fill(y() < 0, 0.5)).sum().backward()
            (x.masked_scatter(Variable(y().data < 0.25), z())).sum().backward()
            (x.masked_select(Variable(y().data < 0.25))).sum().backward()
            # NOTE(review): positional `value` (the leading 1) is the legacy
            # addcmul/addcdiv signature; removed in later PyTorch releases.
            (x.addcmul(1, y(), z())).sum().backward()
            (x.addcdiv(1, y(), z())).sum().backward()
            (x.abs() ** y()).sum().backward()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号