def make_tensor_reader(typename):
    """Build a deserializer for torch7 tensors of the given type.

    ``typename`` is resolved to a Python tensor class via
    ``get_python_class``. The returned ``read_tensor(reader, version)``
    callable reads one serialized tensor record from ``reader`` and rebuilds
    it as an instance of that class backed by the deserialized storage.
    """
    python_class = get_python_class(typename)

    def _as_int_tuple(values):
        """Normalize a long-array read result to a tuple of Python ints."""
        # NOTE(review): readers built on struct.unpack-style helpers may hand
        # back a bare integer when exactly one long is read (ndim == 1).
        # Wrapping it keeps it from being mistaken for a length.
        # TODO confirm the reader's read_long_array contract.
        try:
            items = iter(values)
        except TypeError:
            return (int(values),)
        return tuple(int(v) for v in items)

    def read_tensor(reader, version):
        """Read one tensor record from ``reader`` and return a tensor.

        Field order follows torch7's tensor writer:
        https://github.com/torch/torch7/blob/master/generic/Tensor.c#L1243

        Args:
            reader: object providing ``read_int``, ``read_long_array``,
                ``read_long`` and ``read``.
            version: part of the reader-callback signature; not used here.

        Returns:
            An empty ``python_class()`` when the record has no storage or no
            dimensions. Otherwise, a ``python_class`` instance set to view
            ``storage`` with the decoded offset, size and stride.
        """
        ndim = reader.read_int()
        # Shape metadata only needs plain ints. torch.Size and Tensor.set_
        # accept int sequences, so no (deprecated) torch.LongStorage wrapper
        # is needed.
        size = _as_int_tuple(reader.read_long_array(ndim))
        stride = _as_int_tuple(reader.read_long_array(ndim))
        # torch7 offsets count from 1; PyTorch's set_ expects a 0-based offset.
        storage_offset = reader.read_long() - 1
        # All fields are read before the emptiness check below, presumably so
        # the reader is left positioned after this whole record.
        storage = reader.read()

        if storage is None or ndim == 0 or not size or not stride:
            # Empty tensor.
            return python_class()

        return python_class().set_(
            storage, storage_offset, torch.Size(size), stride
        )

    return read_tensor
评论列表
文章目录