test_torch.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码
def test_serialization_backwards_compat(self):
        """Check that a fixture written by a legacy serializer still loads.

        The expected objects ``b`` are built locally. The fixture
        ``legacy_serialized.pt`` is downloaded into ``<this dir>/data`` and
        loaded with ``torch.load``. The loaded list ``c`` is then checked for:

        * value equality with ``b`` (exact, precision 0);
        * the expected tensor/storage types;
        * preserved aliasing: entries built from the same tensor or storage
          must still share memory after loading. This is checked by
          filling one entry in place and comparing it with its alias.

        If the fixture cannot be downloaded, a RuntimeWarning is emitted and
        the test returns early without checking anything.
        """
        # Two distinct 5x5 float tensors (25 elements each).
        a = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)]
        # b[0] and b[2] are a[0]; b[1] and b[3] are a[1].
        b = [a[i % 2] for i in range(4)]
        # b[4] is a[0]'s whole storage; b[5] is the slice of it at indices 1..3.
        b += [a[0].storage()]
        b += [a[0].storage()[1:4]]
        DATA_URL = 'https://download.pytorch.org/test_data/legacy_serialized.pt'
        data_dir = os.path.join(os.path.dirname(__file__), 'data')
        test_file_path = os.path.join(data_dir, 'legacy_serialized.pt')
        succ = download_file(DATA_URL, test_file_path)
        if not succ:
            # Best effort: without network access the test cannot run, so
            # warn instead of failing.
            warnings.warn(("Couldn't download the test file for backwards compatibility! "
                           "Tests will be incomplete!"), RuntimeWarning)
            return
        c = torch.load(test_file_path)
        self.assertEqual(b, c, 0)
        for loaded_tensor in c[:4]:
            self.assertTrue(isinstance(loaded_tensor, torch.FloatTensor))
        self.assertTrue(isinstance(c[4], torch.FloatStorage))
        # c[0] and c[2] both came from a[0]. An in-place fill of one must show
        # up in the other, and in a[0]'s storage (c[4]).
        c[0].fill_(10)
        self.assertEqual(c[0], c[2], 0)
        self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
        # Same aliasing check for the pair built from a[1].
        c[1].fill_(20)
        self.assertEqual(c[1], c[3], 0)
        # c[5] was saved as c[4]'s [1:4] slice, so it must match that slice.
        # The slice belongs on c[4]: c[5] has only 3 elements, so slicing
        # c[5] instead would compare 25 elements against 2 and always fail.
        self.assertEqual(c[4][1:4], c[5], 0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号