def test_serialization_map_location(self):
    """Check that ``torch.load`` honours ``map_location`` for a downloaded fixture.

    Downloads ``gpu_tensors.pt`` into the ``data`` directory next to this
    file. The file name and the ``'cuda:0'`` key used below suggest it was
    saved from CUDA tensors (not verifiable from here). The fixture is then
    loaded with two ``map_location`` variants: an identity callable and a
    ``{'cuda:0': 'cpu'}`` dict. Each result must be a ``torch.FloatTensor``
    equal to ``[[1.0, 2.0], [3.0, 4.0]]``.

    If the fixture cannot be downloaded, the test is skipped rather than
    silently passed.
    """
    DATA_URL = 'https://download.pytorch.org/test_data/gpu_tensors.pt'
    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    test_file_path = os.path.join(data_dir, 'gpu_tensors.pt')
    if not download_file(DATA_URL, test_file_path):
        # Skip (raises unittest.SkipTest) instead of warn-and-return, so the
        # runner reports the missing fixture rather than counting a pass.
        self.skipTest("Couldn't download the test file for map_location! "
                      "Tests will be incomplete!")

    def map_location(storage, loc):
        # Identity mapping: keep each storage wherever torch.load
        # deserialized it. The assertions below expect a CPU FloatTensor.
        return storage

    expected = torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]])
    # Exercise both the callable form and the dict form of map_location.
    for location in (map_location, {'cuda:0': 'cpu'}):
        tensor = torch.load(test_file_path, map_location=location)
        # isinstance-style check rather than exact type() equality.
        self.assertIsInstance(tensor, torch.FloatTensor)
        self.assertEqual(tensor, expected)
评论列表
文章目录