test_torch.py source code

Project: pytorch  Author: pytorch

def test_slice(self):
    # TODO: remove the Variable wrapper once we merge Variable and Tensor
    from torch.autograd import Variable
    empty = Variable(torch.Tensor())
    x = Variable(torch.arange(0, 16).view(4, 4))
    # as used here, slice takes (start, end, step) positionally; dim is a keyword defaulting to 0
    self.assertEqual(x.slice(), x)
    self.assertEqual(x.slice(0, 4), x)
    # start and stop are clamped to the size of dim
    self.assertEqual(x.slice(0, 5), x)
    # if start >= stop then the result is empty
    self.assertEqual(x.slice(2, 1), empty)
    self.assertEqual(x.slice(2, 2), empty)
    # out of bounds is also empty
    self.assertEqual(x.slice(10, 12), empty)
    # additional correctness checks
    self.assertEqual(x.slice(0, 1).data.tolist(), [[0, 1, 2, 3]])
    self.assertEqual(x.slice(0, -3).data.tolist(), [[0, 1, 2, 3]])
    self.assertEqual(x.slice(-2, 3, dim=1).data.tolist(), [[2], [6], [10], [14]])
    self.assertEqual(x.slice(0, -1, 2).data.tolist(), [[0, 1, 2, 3], [8, 9, 10, 11]])
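The bound handling checked by this test (clamping to the size of the dimension, negative indices counting from the end, empty results when start >= stop) matches the normalisation performed by Python's built-in `slice.indices()`. As a rough, dependency-free sketch, and not PyTorch's actual implementation, the hypothetical helper `slice_along` models the same behaviour on a 4x4 nested list. It assumes the argument order seen in the test (start, end, step, with `dim` as a keyword) and supports positive steps only.

```python
import sys
from typing import List

Matrix = List[List[int]]


def slice_along(x: Matrix, start: int = 0, end: int = sys.maxsize,
                step: int = 1, dim: int = 0) -> Matrix:
    """Hypothetical helper: keep indices along `dim` in [start, end) with stride `step`.

    Bounds follow Python slicing: negative values count from the end of the
    dimension and out-of-range values are clamped, so start >= end (after
    normalisation) simply yields nothing.
    """
    if step <= 0:
        # Only positive strides are modelled in this sketch.
        raise ValueError("step must be positive")
    if dim == 0:
        # slice.indices() normalises negative bounds and clamps them to len(x).
        idx = range(*slice(start, end, step).indices(len(x)))
        return [list(x[i]) for i in idx]
    if dim == 1:
        # Apply the same normalisation to the columns of every row.
        return [[row[j] for j in range(*slice(start, end, step).indices(len(row)))]
                for row in x]
    raise ValueError("only dim 0 and dim 1 are supported for a 2-D list")


x = [list(range(4 * r, 4 * r + 4)) for r in range(4)]  # like arange(0, 16).view(4, 4)
print(slice_along(x, 0, -1, 2))      # [[0, 1, 2, 3], [8, 9, 10, 11]]
print(slice_along(x, -2, 3, dim=1))  # [[2], [6], [10], [14]]
print(slice_along(x, 10, 12))        # []
```

Delegating to `slice.indices()` keeps the clamping rules in one well-defined place instead of re-deriving them with manual `min`/`max` arithmetic.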