test_torch.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:pytorch-coriander 作者: hughperkins 项目源码 文件源码
def test_serialization(self):
        """Round-trip tensors and storages through ``torch.save``/``torch.load``.

        The saved list ``b`` is laid out as:

        * ``b[0:4]`` -- the two tensors of ``a``, each listed twice, so the
          loaded copies at indices 0/2 and 1/3 must still share data;
        * ``b[4]`` -- the full storage of ``a[0]`` (25 floats);
        * ``b[5]`` -- a view of elements 1..3 of that same storage;
        * ``b[6]`` -- an int tensor holding 1..10;
        * ``b[7]`` -- a tuple ``(v, v, w)`` where ``v`` and ``w`` are two
          separately built but equal views of ``a[0]``'s storage;
        * ``b[8]`` -- a view starting at offset 0 of ``a[0]``'s storage.

        Both a file object and a file name are exercised as the
        save/load handle.
        """
        a = [torch.randn(5, 5).float() for i in range(2)]
        b = [a[i % 2] for i in range(4)]
        b += [a[0].storage()]
        b += [a[0].storage()[1:4]]
        b += [torch.arange(1, 11).int()]
        # t1 and t2 are built from two distinct slice objects over the same
        # region, so they should come back as equal but not identical storages.
        t1 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
        t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
        b += [(t1.storage(), t1.storage(), t2.storage())]
        b += [a[0].storage()[0:2]]
        for use_name in (False, True):
            with tempfile.NamedTemporaryFile() as f:
                handle = f if not use_name else f.name
                torch.save(b, handle)
                # Rewind so that loading through the file object (use_name
                # False) starts reading from the beginning of what was saved.
                f.seek(0)
                c = torch.load(handle)
            self.assertEqual(b, c, 0)
            self.assertTrue(isinstance(c[0], torch.FloatTensor))
            self.assertTrue(isinstance(c[1], torch.FloatTensor))
            self.assertTrue(isinstance(c[2], torch.FloatTensor))
            self.assertTrue(isinstance(c[3], torch.FloatTensor))
            self.assertTrue(isinstance(c[4], torch.FloatStorage))
            # Writing through c[0] must be visible through its duplicate c[2]
            # and through the loaded storage c[4], proving sharing survived.
            c[0].fill_(10)
            self.assertEqual(c[0], c[2], 0)
            self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
            c[1].fill_(20)
            self.assertEqual(c[1], c[3], 0)
            # c[5] was saved as elements 1..3 of c[4]'s storage, so compare it
            # with the matching slice of c[4] (slicing the 3-element c[5]
            # itself compared the wrong ranges).
            self.assertEqual(c[4][1:4], c[5], 0)

            # check that serializing the same storage view object unpickles
            # it as one object not two (and vice versa)
            views = c[7]
            self.assertEqual(views[0]._cdata, views[1]._cdata)
            self.assertEqual(views[0], views[2])
            self.assertNotEqual(views[0]._cdata, views[2]._cdata)

            # A view at offset 0 must point at the same memory as c[0].
            rootview = c[8]
            self.assertEqual(rootview.data_ptr(), c[0].data_ptr())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号