def test_serialization_backwards_compat(self):
    """Check that a legacy serialized file still loads with the expected content.

    The expected object list is rebuilt locally: two 5x5 float tensors, each
    referenced twice, followed by the first tensor's storage and a ``[1:4]``
    slice of that storage. ``legacy_serialized.pt`` is then downloaded into
    the ``data`` directory next to this file, loaded with ``torch.load`` and
    compared against the expected list. Finally the test fills loaded
    tensors in place and checks that the entries built from the same tensor
    change together.

    If the download fails, a ``RuntimeWarning`` is emitted and the test
    returns early without running any assertion.
    """
    # Expected content: entries 0/2 are the same tensor object, as are 1/3.
    tensors = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)]
    expected = [tensors[i % 2] for i in range(4)]
    expected += [tensors[0].storage()]        # entry 4: full storage (25 elements)
    expected += [tensors[0].storage()[1:4]]   # entry 5: 3-element slice of entry 4

    data_url = 'https://download.pytorch.org/test_data/legacy_serialized.pt'
    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    test_file_path = os.path.join(data_dir, 'legacy_serialized.pt')

    # download_file is defined outside this block; a falsy result is treated
    # as "file not available".
    if not download_file(data_url, test_file_path):
        # Deliberate best-effort: warn and bail out instead of failing.
        warnings.warn(("Couldn't download the test file for backwards compatibility! "
                       "Tests will be incomplete!"), RuntimeWarning)
        return

    # NOTE(review): torch.load unpickles the file. The URL above is fixed,
    # but this pattern must not be reused for untrusted sources.
    loaded = torch.load(test_file_path)

    # The trailing 0 is forwarded to self.assertEqual exactly as before;
    # presumably a tolerance on this TestCase subclass - TODO confirm.
    self.assertEqual(expected, loaded, 0)

    for tensor in loaded[:4]:
        self.assertIsInstance(tensor, torch.FloatTensor)
    self.assertIsInstance(loaded[4], torch.FloatStorage)

    # Filling entry 0 must be visible through entry 2 and through the
    # storage in entry 4.
    loaded[0].fill_(10)
    self.assertEqual(loaded[0], loaded[2], 0)
    self.assertEqual(loaded[4], torch.FloatStorage(25).fill_(10), 0)

    # Same check for the second tensor pair.
    loaded[1].fill_(20)
    self.assertEqual(loaded[1], loaded[3], 0)

    # Entry 5 was built as entry 4 sliced [1:4], so the slice is applied to
    # the full storage (entry 4), not to the already-sliced entry 5.
    self.assertEqual(loaded[4][1:4], loaded[5], 0)
评论列表
文章目录