test_torch.py source code

python

Project: pytorch-dist  Author: apaszke
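import math

import torch


# Method of the test-case class in test_torch.py (the enclosing class is not shown).
# The trailing 0 passed to assertEqual is the comparison tolerance: results must match exactly.
def test_median(self):
    # Cover both an odd (155) and an even (156) number of elements per row.
    for size in (155, 156):
        x = torch.rand(size, size)
        x0 = x.clone()

        # In the early PyTorch API this test targets, median() without a dim
        # reduces over the last dimension, keeps it with size 1, and returns
        # (values, indices); hence the .select(1, 0) below.
        res1val, res1ind = torch.median(x)
        res2val, res2ind = torch.sort(x)
        # 0-based position of the median in a sorted row (the lower median when size is even).
        ind = int(math.floor((size+1)/2) - 1)

        # Median values and indices must match the middle column of the sorted result.
        self.assertEqual(res2val.select(1, ind), res1val.select(1, 0), 0)
        self.assertEqual(res2ind.select(1, ind), res1ind.select(1, 0), 0)

        # Test use of result tensors (passed as leading positional arguments in this API)
        res2val = torch.Tensor()
        res2ind = torch.LongTensor()
        torch.median(res2val, res2ind, x)
        self.assertEqual(res2val, res1val, 0)
        self.assertEqual(res2ind, res1ind, 0)

        # Test non-default dim
        res1val, res1ind = torch.median(x, 0)
        res2val, res2ind = torch.sort(x, 0)
        self.assertEqual(res1val[0], res2val[ind], 0)
        self.assertEqual(res1ind[0], res2ind[ind], 0)

        # The input must be left unchanged
        self.assertEqual(x, x0, 0)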
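The property this test checks can be reproduced without a PyTorch install. Below is a minimal, dependency-free sketch under stated assumptions: each row's median is found with a plain quickselect and compared against the element at position `floor((size + 1) / 2) - 1` of the sorted row, mirroring the test's value, index, and "input unchanged" checks. The helper names `lower_median_position`, `quickselect` and `row_medians` are hypothetical, and quickselect is an illustrative choice, not a claim about how the TH backend implements `median`.

```python
import math
import random
from typing import List, Sequence, Tuple

# Hypothetical helper names, for illustration only.


def lower_median_position(n: int) -> int:
    """0-based median position in a sorted sequence of length n.

    Same formula as the test: floor((n + 1) / 2) - 1, i.e. the lower
    median when n is even.
    """
    return int(math.floor((n + 1) / 2) - 1)


def quickselect(values: Sequence[float], k: int) -> Tuple[float, int]:
    """Return (value, original_index) of the k-th smallest element (0-based).

    Lomuto-partition quickselect over (value, index) pairs; it works on a
    copy, so the caller's sequence is never modified.
    """
    if not values:
        raise ValueError("quickselect() of an empty sequence")
    items = [(v, i) for i, v in enumerate(values)]
    lo, hi = 0, len(items) - 1
    while True:
        if lo == hi:
            return items[lo]
        # Random pivot, moved to the end for Lomuto partitioning.
        p = random.randint(lo, hi)
        items[p], items[hi] = items[hi], items[p]
        pivot = items[hi]
        store = lo
        for j in range(lo, hi):
            if items[j] < pivot:
                items[store], items[j] = items[j], items[store]
                store += 1
        # The pivot now sits at its final sorted position `store`.
        items[store], items[hi] = items[hi], items[store]
        if k == store:
            return items[store]
        if k < store:
            hi = store - 1
        else:
            lo = store + 1


def row_medians(matrix: List[List[float]]) -> Tuple[List[float], List[int]]:
    """Lower median of each row, plus the column index where it occurs."""
    vals, inds = [], []
    for row in matrix:
        v, i = quickselect(row, lower_median_position(len(row)))
        vals.append(v)
        inds.append(i)
    return vals, inds


if __name__ == "__main__":
    random.seed(0)
    for size in (155, 156):
        x = [[random.random() for _ in range(size)] for _ in range(size)]
        x0 = [row[:] for row in x]
        vals, inds = row_medians(x)
        ind = lower_median_position(size)
        for row, v, i in zip(x, vals, inds):
            # Reference result: sort the row together with original indices.
            ref = sorted((val, j) for j, val in enumerate(row))
            assert (v, i) == ref[ind]
        assert x == x0  # the input is not modified
    print("ok")
```

For both sizes the formula gives position 77, so for 156 elements the lower of the two middle values is expected. Note that current PyTorch releases changed the API this test was written against: `torch.median(x)` without `dim` now returns the median of all elements, `torch.median(x, dim)` returns a `(values, indices)` named tuple with the reduced dimension dropped unless `keepdim=True`, and for an even number of elements the lower of the two medians is returned.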