def _require_float_tensor(t):
    """Return ``t`` unchanged if it is a ``torch.FloatTensor``; raise TypeError otherwise."""
    # An explicit raise replaces the original ``assert``: assert statements are
    # stripped under ``python -O``, which would let a wrong-typed tensor through.
    if not isinstance(t, torch.FloatTensor):
        kind = t.type() if torch.is_tensor(t) else type(t).__name__
        raise TypeError("expected torch.FloatTensor, got {}".format(kind))
    return t


def to_float_tensor(x, copy=True):
    """
    Convert ``x`` to a ``torch.FloatTensor``.

    FloatTensor is the most used torch type, so it gets a dedicated converter.

    Args:
        x: one of
            - a torch tensor: returned as-is (it must already be a FloatTensor);
            - an object for which ``TC.is_variable(x)`` is true: unwrapped with
              ``TC.to_tensor`` and returned (it must unwrap to a FloatTensor);
            - a numpy array, or anything ``np.asarray`` accepts (lists,
              scalars, ...): passed through ``np_cast(..., np.float32)`` and
              wrapped in a tensor.
        copy: only affects the numpy path. True builds the result with
            ``torch.FloatTensor(x)``; False uses ``torch.from_numpy(x)``, which
            shares memory with the array returned by ``np_cast``.

    Returns:
        torch.FloatTensor

    Raises:
        TypeError: if ``x`` is, or unwraps to, something that is not a
            ``torch.FloatTensor``. Existing tensors are never converted.
    """
    if torch.is_tensor(x):
        return _require_float_tensor(x)
    if TC.is_variable(x):
        return _require_float_tensor(TC.to_tensor(x))
    if not TC.is_numpy(x):
        # np.asarray instead of np.array(x, copy=False): under NumPy >= 2.0,
        # copy=False raises ValueError whenever a copy is needed (e.g. for a
        # Python list); asarray copies only when necessary on every version.
        x = np.asarray(x)
    # NOTE(review): np_cast is defined elsewhere; presumably it converts the
    # array to float32 — confirm whether it copies when the dtype already matches.
    x = np_cast(x, np.float32)
    if copy:
        return torch.FloatTensor(x)
    return torch.from_numpy(x)
评论列表
文章目录