test_torch.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目: pytorch-dist 作者: apaszke 项目源码 文件源码
def test_mode(self):
        """Verify torch.mode values and indices on a matrix whose mode is known.

        The matrix is SIZE x SIZE, filled with 1..SIZE*SIZE. Its first two rows
        and first two columns are then overwritten with 1, so 1 occurs more
        often than any other value in every row and every column.

        Three call forms are checked:
        - the return-value form;
        - the form that writes into caller-supplied result tensors;
        - an explicit ``dim=0``.

        The test also confirms that the input matrix is left untouched.
        """
        def _check(want_val, want_ind, got_val, got_ind):
            # Exact comparison (precision 0) of both the values and the indices.
            self.assertEqual(want_val, got_val, 0)
            self.assertEqual(want_ind, got_ind, 0)

        mat = torch.range(1, SIZE * SIZE).clone().resize_(SIZE, SIZE)
        mat[:2] = 1
        mat[:, :2] = 1
        # Keep a copy so we can later prove torch.mode did not mutate `mat`.
        snapshot = mat.clone()

        # Expected output without a dim argument. Shape (SIZE, 1) implies a
        # per-row reduction.
        # Every row's mode is 1.
        expected_val = torch.Tensor(SIZE, 1).fill_(1)
        # The expected index is where the mode element appears LAST:
        # - rows 0 and 1 are all ones, so the index is SIZE - 1;
        # - the other rows hold 1 only in columns 0 and 1, so the index is 1.
        expected_ind = torch.LongTensor(SIZE, 1).fill_(1)
        expected_ind[0] = SIZE - 1
        expected_ind[1] = SIZE - 1

        # Return-value form.
        got_val, got_ind = torch.mode(mat)
        _check(expected_val, expected_ind, got_val, got_ind)

        # Out-argument form: results land in the tensors we pass in.
        out_val, out_ind = torch.Tensor(), torch.LongTensor()
        torch.mode(out_val, out_ind, mat)
        _check(expected_val, expected_ind, out_val, out_ind)

        # Explicit dim 0. The ones-pattern is symmetric, so the per-column
        # expectations are the per-row ones laid out as a single row.
        col_val, col_ind = torch.mode(mat, 0)
        _check(expected_val.view(1, SIZE), expected_ind.view(1, SIZE),
               col_val, col_ind)

        # None of the calls above may have modified the input.
        self.assertEqual(mat, snapshot, 0)
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号