test_torch.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def test_range(self):
    """Exercise ``torch.range`` in both its allocating and out-tensor forms.

    Covers:
    - allocating a new result vs. filling a caller-supplied tensor;
    - writing into a non-contiguous view;
    - a negative step;
    - equal start/end bounds with either step sign;
    - the element count produced by fractional steps, for both
      FloatTensor and DoubleTensor.
    """
    # The allocating form and the out-tensor-first form must agree exactly.
    expected = torch.range(0, 1)
    out = torch.Tensor()
    torch.range(out, 0, 1)
    self.assertEqual(expected, out, 0)

    # Fill a non-contiguous view: columns 1-2 of a 2x3 zero matrix
    # receive 0..3 in row-major order, and column 0 stays zero.
    grid = torch.zeros(2, 3)
    torch.range(grid.narrow(1, 1, 2), 0, 3)
    want = torch.Tensor(((0, 0, 1), (0, 2, 3)))
    self.assertEqual(grid, want, 1e-16)

    # Descending range driven by a negative step.
    expected = torch.Tensor((1, 0))
    out = torch.Tensor()
    torch.range(out, 1, 0, -1)
    self.assertEqual(expected, out, 0)

    # start == end yields a single element for either step sign.
    # The same output tensor is deliberately reused across both calls.
    expected = torch.ones(1)
    out = torch.Tensor()
    for step in (-1, 1):
        torch.range(out, 1, 1, step)
        self.assertEqual(expected, out, 0)

    # Fractional steps. The asserted counts (4 for [0.6, 0.9] by 0.1,
    # 31 for [1, 10] by 0.3) imply the end point is included.
    # Checked for both floating-point tensor types.
    fractional_cases = ((0.6, 0.9, 0.1, 4), (1, 10, 0.3, 31))
    for tensor_cls in (torch.FloatTensor, torch.DoubleTensor):
        for lo, hi, step, count in fractional_cases:
            produced = torch.range(tensor_cls(), lo, hi, step)
            self.assertEqual(produced.size(0), count)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号