test_torch.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def test_gesv(self):
        a = torch.Tensor(((6.80, -2.11,  5.66,  5.97,  8.23),
                        (-6.05, -3.30,  5.36, -4.44,  1.08),
                        (-0.45,  2.58, -2.70,  0.27,  9.04),
                        (8.32,  2.71,  4.35, -7.17,  2.14),
                        (-9.67, -5.14, -7.26,  6.08, -6.87))).t()
        b = torch.Tensor(((4.02,  6.19, -8.22, -7.57, -3.03),
                        (-1.56,  4.00, -8.67,  1.75,  2.86),
                        (9.81, -4.09, -4.57, -8.61,  8.99))).t()

        res1 = torch.gesv(b,a)[0]
        self.assertLessEqual(b.dist(torch.mm(a, res1)), 1e-12)
        ta = torch.Tensor()
        tb = torch.Tensor()
        res2 = torch.gesv(tb, ta, b, a)[0]
        res3 = torch.gesv(b, a, b, a)[0]
        self.assertEqual(res1, tb)
        self.assertEqual(res1, b)
        self.assertEqual(res1, res2)
        self.assertEqual(res1, res3)

        # test reuse
        res1 = torch.gesv(b, a)[0]
        ta = torch.Tensor()
        tb = torch.Tensor()
        torch.gesv(tb, ta, b, a)[0]
        self.assertEqual(res1, tb)
        torch.gesv(tb, ta, b, a)[0]
        self.assertEqual(res1, tb)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号