def test_pickle_gpu(self):
self.fs2.to_gpu()
fs2_serialized = pickle.dumps(self.fs2)
fs2_loaded = pickle.loads(fs2_serialized)
fs2_loaded.to_cpu()
self.fs2.to_cpu()
self.assertTrue((self.fs2.b.p.data == fs2_loaded.b.p.data).all())
self.assertTrue(
(self.fs2.fs1.a.p.data == fs2_loaded.fs1.a.p.data).all())
评论列表
文章目录