def sanitize(x: Any) -> Any:  # pylint: disable=invalid-name,too-many-return-statements
    """
    Recursively convert PyTorch and NumPy values into basic Python types so
    that the result can be serialized into JSON.

    Args:
        x: The value to convert. Supported inputs are ``None``, ``str``,
            ``float``, ``int``, ``bool``, PyTorch tensors (and legacy
            ``Variable`` wrappers), NumPy arrays, NumPy numeric and boolean
            scalars, and dicts / lists / tuples whose contents are themselves
            supported.

    Returns:
        ``x`` unchanged when it is already a basic type; a (possibly nested)
        list for tensors and arrays; a Python scalar for NumPy scalars; a new
        dict with the same keys and sanitized values for dicts; a list of
        sanitized elements for lists and tuples.

    Raises:
        ValueError: If ``x`` (or a value nested inside it) is of a type that
            is not handled above.
    """
    if x is None or isinstance(x, (str, float, int, bool)):
        # Already serializable (None becomes JSON null).
        return x
    if torch.is_tensor(x):
        # Checked before the Variable branch: on current PyTorch every tensor
        # is also an instance of ``Variable`` and ``x.data`` is again a
        # tensor, so testing Variable first would recurse forever.
        # Move to the CPU if necessary, then convert to nested lists.
        return x.cpu().tolist()
    if isinstance(x, torch.autograd.Variable):
        # Legacy PyTorch: a Variable wraps a tensor; unwrap and retry.
        return sanitize(x.data)
    if isinstance(x, numpy.ndarray):
        # Arrays are converted to (nested) lists.
        return x.tolist()
    if isinstance(x, (numpy.number, numpy.bool_)):
        # NumPy scalars become Python scalars. numpy.bool_ is listed
        # explicitly because it is not a subclass of numpy.number.
        return x.item()
    if isinstance(x, dict):
        # Keys are kept as they are; only the values are sanitized.
        return {key: sanitize(value) for key, value in x.items()}
    if isinstance(x, (list, tuple)):
        # Tuples are returned as lists, matching their JSON representation.
        return [sanitize(x_i) for x_i in x]
    raise ValueError("cannot sanitize {} of type {}".format(x, type(x)))
评论列表
文章目录