def __init__(self, dtype='float'):
    """
    Cast a torch.Tensor to a different type.

    Arguments
    ---------
    dtype : string or torch.*Tensor literal or list/tuple of such
        data type to which input(s) will be cast.
        The string aliases 'byte', 'double', 'float', 'int', 'long' and
        'short' are translated to th.ByteTensor, th.DoubleTensor,
        th.FloatTensor, th.IntTensor, th.LongTensor and th.ShortTensor
        respectively. Any other value, including an unrecognised string,
        is stored unchanged.
        If a list or tuple, it should be the same length as inputs; the
        resolved types are stored as a list.
    """
    # One alias table shared by the scalar and the sequence branch, so the
    # two code paths cannot disagree about what a name maps to.
    aliases = {
        'byte': th.ByteTensor,
        'double': th.DoubleTensor,
        'float': th.FloatTensor,
        'int': th.IntTensor,
        'long': th.LongTensor,
        'short': th.ShortTensor,
    }

    def resolve(dt):
        # Only known string aliases are translated. Other strings (e.g. a
        # full type name) and non-string values pass through untouched.
        if isinstance(dt, str):
            return aliases.get(dt, dt)
        return dt

    if isinstance(dtype, (list, tuple)):
        self.dtype = [resolve(dt) for dt in dtype]
    else:
        self.dtype = resolve(dtype)
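# Web-page extraction residue (not code): "Comment list"
# Web-page extraction residue (not code): "Table of contents"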