read_lua_file.py source code

python

Project: pytorch  Author: tylergenter
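import torch


def make_tensor_reader(typename):
    # get_python_class is a helper defined elsewhere in read_lua_file.py (not shown);
    # it resolves the Torch7 type name to the tensor class to instantiate.
    python_class = get_python_class(typename)

    def read_tensor(reader, version):
        # Mirrors Torch7's tensor deserialization; source:
        # https://github.com/torch/torch7/blob/master/generic/Tensor.c#L1243
        ndim = reader.read_int()

        # read size (one entry per dimension):
        size = torch.LongStorage(reader.read_long_array(ndim))
        # read stride (one entry per dimension):
        stride = torch.LongStorage(reader.read_long_array(ndim))
        # storage offset: Lua indices are 1-based, PyTorch offsets are 0-based
        storage_offset = reader.read_long() - 1
        # read the underlying storage object:
        storage = reader.read()

        if storage is None or ndim == 0 or len(size) == 0 or len(stride) == 0:
            # empty torch tensor
            return python_class()

        # build a view over the storage using the saved offset, shape and strides
        return python_class().set_(storage, storage_offset, torch.Size(size), tuple(stride))
    return read_tensor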
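The header that read_tensor consumes (ndim, size, stride, a 1-based offset) can be illustrated without torch. The following is a minimal sketch, not the module's actual reader: SketchReader, read_tensor_header, strided_elements and encode_header are hypothetical names, and the byte layout (little-endian 4-byte ints, 8-byte longs) is an assumption made only for this example. It shows the same read order, the `- 1` offset conversion, and how size/stride address elements in the flat storage.

```python
import struct
from io import BytesIO
from typing import List, Sequence, Tuple


class SketchReader:
    """Hypothetical stand-in for the reader passed to read_tensor.
    ASSUMPTION: little-endian 4-byte ints and 8-byte longs."""

    def __init__(self, data: bytes) -> None:
        self._buf = BytesIO(data)

    def _unpack(self, fmt: str) -> tuple:
        n = struct.calcsize(fmt)
        return struct.unpack(fmt, self._buf.read(n))

    def read_int(self) -> int:
        return self._unpack("<i")[0]

    def read_long(self) -> int:
        return self._unpack("<q")[0]

    def read_long_array(self, n: int) -> List[int]:
        return list(self._unpack("<%dq" % n))


def read_tensor_header(reader: SketchReader) -> Tuple[List[int], List[int], int]:
    """Read ndim, size, stride and offset in the same order as read_tensor."""
    ndim = reader.read_int()
    size = reader.read_long_array(ndim)
    stride = reader.read_long_array(ndim)
    # Lua indices start at 1; Python/C offsets start at 0.
    storage_offset = reader.read_long() - 1
    return size, stride, storage_offset


def strided_elements(storage: Sequence, offset: int,
                     size: Sequence[int], stride: Sequence[int]) -> list:
    """Return the view's elements in row-major order: element (i0, i1, ...)
    lives at storage[offset + i0*stride[0] + i1*stride[1] + ...]."""
    if len(size) == 0:
        return []  # mirrors the "empty tensor" branch of read_tensor
    out = []

    def walk(dim: int, base: int) -> None:
        if dim == len(size):
            out.append(storage[base])
            return
        for i in range(size[dim]):
            walk(dim + 1, base + i * stride[dim])

    walk(0, offset)
    return out


def encode_header(size: Sequence[int], stride: Sequence[int], lua_offset: int) -> bytes:
    """Pack a header in the layout SketchReader expects (for testing only)."""
    n = len(size)
    return (struct.pack("<i", n)
            + struct.pack("<%dq" % n, *size)
            + struct.pack("<%dq" % n, *stride)
            + struct.pack("<q", lua_offset))


if __name__ == "__main__":
    # A 3x2 transposed view of a 2x3 row-major storage, Lua offset 1.
    size, stride, off = read_tensor_header(SketchReader(encode_header([3, 2], [1, 3], 1)))
    print(off)  # 0
    print(strided_elements([10, 11, 12, 13, 14, 15], off, size, stride))  # [10, 13, 11, 14, 12, 15]
```

Swapping the strides from [3, 1] to [1, 3] yields a transposed view over the same storage without copying, which is why read_tensor passes the saved stride to set_ instead of assuming a contiguous layout.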