import contextlib


@contextlib.contextmanager
def tensors_default_to(host):
    """
    Context manager to temporarily make PyTorch create CPU or CUDA tensors by default.

    On entry the current default tensor type (e.g. ``torch.FloatTensor``) is
    recorded and replaced by the same element type on ``host`` (e.g.
    ``torch.cuda.FloatTensor``). On exit -- including when the ``with`` body
    raises -- the recorded type is restored exactly.

    :param str host: Either "cuda" or "cpu".
    :raises ValueError: If ``host`` is not "cpu" or "cuda" (raised when the
        ``with`` block is entered).
    """
    # Local import so merely importing this module does not require torch.
    import torch

    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if host not in ('cpu', 'cuda'):
        raise ValueError("host must be 'cpu' or 'cuda', got {!r}".format(host))

    # Record the *actual* current default, e.g. 'torch.FloatTensor' or
    # 'torch.cuda.DoubleTensor'. On PyTorch >= 0.4, torch.Tensor.__module__ and
    # torch.Tensor.__name__ are always 'torch' / 'Tensor', so they cannot be
    # used to discover (or later restore) the default type.
    old_type = torch.Tensor().type()

    # Keep the element type (FloatTensor, DoubleTensor, ...) and swap only the
    # device prefix.
    type_name = old_type.rsplit('.', 1)[-1]
    new_module = 'torch.cuda' if host == 'cuda' else 'torch'
    torch.set_default_tensor_type('{}.{}'.format(new_module, type_name))
    try:
        yield
    finally:
        torch.set_default_tensor_type(old_type)
# Web-page navigation text captured with this snippet ("Comment list" / "Table of contents") -- not code.