torch_utils.py source code

python

Project: inferno  Author: inferno-pytorch
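import numpy as np
import torch
from torch.autograd import Variable

# NotUnwrappableError and delayed_keyboard_interrupt are not shown in this
# excerpt; they are defined elsewhere in the inferno package.


def unwrap(tensor_or_variable, to_cpu=True, as_numpy=False):
    """Return the tensor behind a Variable, optionally on CPU or as a NumPy array.

    Lists and tuples are unwrapped element-wise (the container type is kept);
    NumPy arrays and Python numbers are returned unchanged.
    """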
    if isinstance(tensor_or_variable, (list, tuple)):
        return type(tensor_or_variable)([unwrap(_t, to_cpu=to_cpu, as_numpy=as_numpy)
                                         for _t in tensor_or_variable])
    elif isinstance(tensor_or_variable, Variable):
        tensor = tensor_or_variable.data
    elif torch.is_tensor(tensor_or_variable):
        tensor = tensor_or_variable
    elif isinstance(tensor_or_variable, np.ndarray):
        return tensor_or_variable
    elif isinstance(tensor_or_variable, (float, int)):
        return tensor_or_variable
    else:
        raise NotUnwrappableError("Cannot unwrap a '{}'."
                                  .format(type(tensor_or_variable).__name__))
    # Transfer to CPU if required
    if to_cpu:
        with delayed_keyboard_interrupt():
            tensor = tensor.cpu()
    # Convert to numpy if required
    if as_numpy:
        return tensor.cpu().numpy()
    else:
        return tensor
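
The CPU transfer above runs inside `delayed_keyboard_interrupt()`, which is not part of this excerpt. Judging by its name and usage, it presumably holds back Ctrl-C until the guarded block has finished, so that the copy is not interrupted halfway. The following is only a minimal sketch of how such a helper could be written with the standard `signal` module; the name `delayed_keyboard_interrupt_sketch` is hypothetical, and inferno's actual implementation may differ.

```python
import signal
from contextlib import contextmanager


@contextmanager
def delayed_keyboard_interrupt_sketch():
    """Hypothetical stand-in: defer SIGINT until the with-block has finished."""
    received = []

    def _record(signum, frame):
        # Remember the interrupt instead of raising KeyboardInterrupt right away.
        received.append((signum, frame))

    # signal.signal returns the handler that was installed before.
    # Note: handlers can only be installed from the main thread.
    previous = signal.signal(signal.SIGINT, _record)
    try:
        yield
    finally:
        # Restore the earlier handler (None means it was not set from Python).
        if previous is not None:
            signal.signal(signal.SIGINT, previous)
        # Replay the first deferred interrupt through the earlier handler.
        if received and callable(previous):
            previous(*received[0])
```

With Python's default handler in place, pressing Ctrl-C inside `with delayed_keyboard_interrupt_sketch():` lets the block run to completion and raises `KeyboardInterrupt` only on exit.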