def test_serialization(self):
    """Round-trip tensors and storages through ``torch.save`` / ``torch.load``.

    The saved list holds two float tensors, each referenced twice, followed
    by the first tensor's storage and a ``[1:4]`` slice of that storage.
    The round trip is performed twice: once passing the open temporary file
    object as the handle and once passing the file's name.

    After loading, the test checks that the loaded list equals the original,
    that every element has the expected type, and that writes made through
    one loaded entry are visible through the entries that aliased it before
    saving.
    """
    a = [torch.randn(5, 5).float() for _ in range(2)]
    # b[0]/b[2] refer to a[0]; b[1]/b[3] refer to a[1].
    b = [a[i % 2] for i in range(4)]
    b += [a[0].storage()]
    b += [a[0].storage()[1:4]]
    for use_name in (False, True):
        with tempfile.NamedTemporaryFile() as f:
            handle = f.name if use_name else f
            torch.save(b, handle)
            # Rewind so a load from the file object starts at the beginning.
            f.seek(0)
            c = torch.load(handle)
        # NOTE(review): the trailing 0 is presumably the allowed tolerance of
        # the test base class's assertEqual -- confirm against that class.
        self.assertEqual(b, c, 0)
        for index in range(4):
            self.assertIsInstance(c[index], torch.FloatTensor)
        self.assertIsInstance(c[4], torch.FloatStorage)
        # Filling c[0] must also show up in c[2] and in the loaded storage.
        c[0].fill_(10)
        self.assertEqual(c[0], c[2], 0)
        self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
        c[1].fill_(20)
        self.assertEqual(c[1], c[3], 0)
        # c[5] was saved as storage()[1:4], so it must match that same slice
        # of the loaded full storage (the slice belongs on c[4], not on c[5]).
        self.assertEqual(c[4][1:4], c[5], 0)
# Comment list
# Article contents