test_torch.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def test_newindex(self):
        reference = self._consecutive((3, 3, 3))
        # This relies on __index__() being correct - but we have separate tests for that
        def checkPartialAssign(index):
            reference = torch.zeros(3, 3, 3)
            reference[index] = self._consecutive((3, 3, 3))[index]
            self.assertEqual(reference[index], self._consecutive((3, 3, 3))[index], 0)
            reference[index] = 0
            self.assertEqual(reference, torch.zeros(3, 3, 3), 0)

        checkPartialAssign(0)
        checkPartialAssign(1)
        checkPartialAssign(2)
        checkPartialAssign((0, 1))
        checkPartialAssign((1, 2))
        checkPartialAssign((0, 2))

        with self.assertRaises(RuntimeError):
            reference[1, 1, 1, 1] = 1
        with self.assertRaises(RuntimeError):
            reference[1, 1, 1, (1, 1)] = 1
        with self.assertRaises(RuntimeError):
            reference[3, 3, 3, 3, 3, 3, 3, 3] = 1
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号