def unwrap(tensor_or_variable, to_cpu=True, as_numpy=False):
    """Unwrap a torch Variable/tensor, or a (nested) list/tuple of them.

    Lists and tuples, including namedtuples, are unwrapped element-wise and
    rebuilt with their original container type. NumPy arrays, Python
    ``int``/``float`` values and NumPy scalars are returned unchanged.

    Parameters
    ----------
    tensor_or_variable : Variable, torch.Tensor, np.ndarray, int, float,
        NumPy scalar, or a (possibly nested) list/tuple of these.
    to_cpu : bool
        If True, move the tensor to CPU before returning it.
    as_numpy : bool
        If True, convert the tensor to a NumPy array.

    Returns
    -------
    The unwrapped tensor / array / number. Containers are rebuilt with the
    same type as the input.

    Raises
    ------
    NotUnwrappableError
        If ``tensor_or_variable`` is of an unsupported type.
    """
    if isinstance(tensor_or_variable, (list, tuple)):
        unwrapped = [unwrap(_t, to_cpu=to_cpu, as_numpy=as_numpy)
                     for _t in tensor_or_variable]
        container_type = type(tensor_or_variable)
        if hasattr(tensor_or_variable, '_fields'):
            # namedtuples take their fields as positional arguments,
            # not as a single iterable like list/tuple do.
            return container_type(*unwrapped)
        return container_type(unwrapped)
    elif isinstance(tensor_or_variable, Variable):
        # Work on the wrapped tensor rather than the Variable itself.
        tensor = tensor_or_variable.data
    elif torch.is_tensor(tensor_or_variable):
        tensor = tensor_or_variable
    elif isinstance(tensor_or_variable, np.ndarray):
        return tensor_or_variable
    elif isinstance(tensor_or_variable, (float, int, np.generic)):
        # Plain numbers and NumPy scalars (e.g. np.float32, np.int64)
        # need no unwrapping.
        return tensor_or_variable
    else:
        raise NotUnwrappableError("Cannot unwrap a '{}'."
                                  .format(type(tensor_or_variable).__name__))
    # Transfer to CPU if required
    if to_cpu:
        # NOTE(review): delayed_keyboard_interrupt is defined elsewhere;
        # presumably it defers Ctrl-C until the transfer completes — confirm.
        with delayed_keyboard_interrupt():
            tensor = tensor.cpu()
    # Convert to numpy if required. The explicit .cpu() here matters when
    # to_cpu is False and the tensor may still live on another device.
    if as_numpy:
        return tensor.cpu().numpy()
    else:
        return tensor