common.py 文件源码

python
阅读 39 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def tensors_default_to(host):
    """
    Context manager to temporarily use Cpu or Cuda tensors in Pytorch.

    :param str host: Either "cuda" or "cpu".
    """
    assert host in ('cpu', 'cuda'), host
    old_module = torch.Tensor.__module__
    name = torch.Tensor.__name__
    new_module = 'torch.cuda' if host == 'cuda' else 'torch'
    torch.set_default_tensor_type('{}.{}'.format(new_module, name))
    try:
        yield
    finally:
        torch.set_default_tensor_type('{}.{}'.format(old_module, name))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号