test_torch.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def tset_potri(self):
        a=torch.Tensor(((6.80, -2.11,  5.66,  5.97,  8.23),
                        (-6.05, -3.30,  5.36, -4.44,  1.08),
                        (-0.45,  2.58, -2.70,  0.27,  9.04),
                        (8.32,  2.71,  4.35, -7.17,  2.14),
                        (-9.67, -5.14, -7.26,  6.08, -6.87))).t()

        # make sure 'a' is symmetric PSD
        a = a * a.t()

        # compute inverse directly
        inv0 = torch.inverse(a)

        # default case
        chol = torch.potrf(a)
        inv1 = torch.potri(chol)
        self.assertLessEqual(inv0.dist(inv1), 1e-12)

        # upper Triangular Test
        chol = torch.potrf(a, 'U')
        inv1 = torch.potri(chol, 'U')
        self.assertLessEqual(inv0.dist(inv1), 1e-12)

        # lower Triangular Test
        chol = torch.potrf(a, 'L')
        inv1 = torch.potri(chol, 'L')
        self.assertLessEqual(inv0.dist(inv1), 1e-12)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号