def test_deepcopy(self):
    """Check that ``deepcopy`` copies tensor/storage values and keeps sharing.

    A nested list holding two tensors, their storages, and a flat view of
    the first tensor is deep-copied. The copy must first compare equal to
    the original, and then in-place edits to the copied tensors must show
    up in the copied storages and the copied view, while the originals
    stay unchanged.
    """
    from copy import deepcopy

    a = torch.randn(5, 5)
    b = torch.randn(5, 5)
    c = a.view(25)  # flat view of ``a``
    q = [a, [a.storage(), b.storage()], b, c]
    w = deepcopy(q)

    # Value equality of the copy. The trailing ``0`` is forwarded as the
    # third positional argument of ``self.assertEqual`` -- presumably a
    # zero tolerance; confirm against the TestCase base class in use.
    self.assertEqual(w[0], q[0], 0)
    self.assertEqual(w[1][0], q[1][0], 0)
    self.assertEqual(w[1][1], q[1][1], 0)
    self.assertEqual(w[1], q[1], 0)
    self.assertEqual(w[2], q[2], 0)
    self.assertEqual(w[3], q[3], 0)

    # Check that deepcopy preserves sharing
    # Only the copy is mutated, so the originals in ``q`` serve as the
    # reference values for the expected results below.
    w[0].add_(1)
    for i in range(a.numel()):
        self.assertEqual(w[1][0][i], q[1][0][i] + 1)
    self.assertEqual(w[3], c + 1)

    w[2].sub_(1)
    for i in range(a.numel()):
        self.assertEqual(w[1][1][i], q[1][1][i] - 1)
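# Comment list (web-page navigation residue, not part of the test)
# Article table of contents (web-page navigation residue, not part of the test)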