def test_newindex(self):
    """Exercise indexed assignment (``tensor[index] = value``).

    For a handful of integer and tuple indices, values are written into a
    zero-filled tensor, read back and compared, then cleared again and the
    whole tensor is compared against zeros.  Afterwards, assignments that
    use index tuples with four or eight entries are expected to raise
    ``RuntimeError``.
    """
    shape = (3, 3, 3)
    reference = self._consecutive(shape)

    # Reading the values back relies on indexed reads being correct; the
    # original note says those are covered by separate tests.
    def assert_assign_roundtrip(index):
        # Start from zeros so that only the assigned slice can be non-zero.
        scratch = torch.zeros(*shape)
        scratch[index] = self._consecutive(shape)[index]
        expected = self._consecutive(shape)[index]
        self.assertEqual(scratch[index], expected, 0)

        # Clearing the same slice must restore an all-zero tensor.
        scratch[index] = 0
        self.assertEqual(scratch, torch.zeros(*shape), 0)

    for index in (0, 1, 2, (0, 1), (1, 2), (0, 2)):
        assert_assign_roundtrip(index)

    # Index tuples longer than the (3, 3, 3) shape passed to _consecutive
    # (presumably the tensor's shape - confirm against the helper).
    over_indexed = (
        (1, 1, 1, 1),
        (1, 1, 1, (1, 1)),
        (3, 3, 3, 3, 3, 3, 3, 3),
    )
    for bad_index in over_indexed:
        with self.assertRaises(RuntimeError):
            reference[bad_index] = 1
评论列表
文章目录