test_torch.py source code

python

Project: pytorch-coriander  Author: hughperkins
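import os
import warnings

import torch
from common import download_file  # assumed: helper from the repository's test/common.py


def test_serialization_map_location(self):
    # The fixture holds a tensor saved on 'cuda:0', so it has to be remapped
    # to load on a CPU-only machine. It is downloaded on demand.
    DATA_URL = 'https://download.pytorch.org/test_data/gpu_tensors.pt'
    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    test_file_path = os.path.join(data_dir, 'gpu_tensors.pt')
    succ = download_file(DATA_URL, test_file_path)
    if not succ:
        # Without the fixture the checks cannot run: warn and return early.
        warnings.warn(
            "Couldn't download the test file for map_location! "
            "Tests will be incomplete!", RuntimeWarning)
        return

    # A callable map_location gets each storage as first deserialized (on the
    # CPU) plus its saved location tag; returning it unchanged keeps it on the CPU.
    def map_location(storage, loc):
        return storage

    tensor = torch.load(test_file_path, map_location=map_location)
    self.assertEqual(type(tensor), torch.FloatTensor)
    self.assertEqual(tensor, torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))

    # A dict map_location renames location tags: 'cuda:0' storages load on the CPU.
    tensor = torch.load(test_file_path, map_location={'cuda:0': 'cpu'})
    self.assertEqual(type(tensor), torch.FloatTensor)
    self.assertEqual(tensor, torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))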
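The test passes map_location to torch.load in two forms, a callable (storage, loc) and a dict that renames location tags, and expects a CPU FloatTensor both times. The plain-Python sketch below is a simplified mental model of how those forms (plus the string form torch.load also accepts) could be dispatched. It is not PyTorch's code: make_restore_location, default_restore and the dict-based "storage" are hypothetical stand-ins, and real torch.load also accepts a torch.device, which is omitted here.

```python
from typing import Callable, Dict, Optional, Union

Storage = dict  # hypothetical stand-in for a tensor storage
RestoreFn = Callable[[Storage, str], Storage]


def default_restore(storage: Storage, location: str) -> Storage:
    """Pretend to move `storage` to the device named by `location`."""
    return {**storage, "location": location}


def make_restore_location(
    map_location: Union[None, str, Dict[str, str], Callable[[Storage, str], Optional[Storage]]],
) -> RestoreFn:
    if map_location is None:
        # No remapping: each storage goes back to the device it was saved from.
        return default_restore
    if isinstance(map_location, dict):
        # Rename the saved tag if it appears in the dict, e.g. {'cuda:0': 'cpu'}.
        return lambda storage, loc: default_restore(storage, map_location.get(loc, loc))
    if isinstance(map_location, str):
        # Send every storage to one fixed device, ignoring the saved tag.
        return lambda storage, loc: default_restore(storage, map_location)

    # Otherwise treat it as a callable. It receives the storage as first
    # deserialized (on the CPU) plus the saved tag; returning that storage
    # unchanged keeps it on the CPU, returning None falls back to the default.
    def restore(storage: Storage, loc: str) -> Storage:
        result = map_location(storage, loc)
        return default_restore(storage, loc) if result is None else result

    return restore


# Usage: a storage saved on 'cuda:0', already deserialized into CPU memory.
saved_tag = "cuda:0"
cpu_storage = {"data": [1.0, 2.0, 3.0, 4.0], "location": "cpu"}

for ml in (lambda storage, loc: storage, {"cuda:0": "cpu"}, "cpu", None):
    print(make_restore_location(ml)(cpu_storage, saved_tag)["location"])
# prints cpu, cpu, cpu, cuda:0 (one per line)
```

In this model the dict form only rewrites tags it names, so a storage tagged 'cuda:1' would pass through {'cuda:0': 'cpu'} unchanged; the test's dict works because the fixture was saved on 'cuda:0'.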