__init__.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def sparse_repeat(sparse, *repeat_sizes):
    """Repeat a sparse COO tensor along each dimension, like ``Tensor.repeat``.

    The dense equivalent of the result equals
    ``sparse.to_dense().repeat(*repeat_sizes)``. When more repeat sizes than
    tensor dimensions are given, size-1 leading dimensions are prepended
    before repeating (the same rule ``Tensor.repeat`` follows).

    Args:
        sparse: the sparse tensor to repeat. Its index rows are assumed to
            cover every dimension, i.e. no dense/hybrid dimensions
            (NOTE(review): confirm no caller passes hybrid sparse tensors).
        *repeat_sizes: number of repetitions for each output dimension, given
            either as separate ints or as a single tuple/list/``torch.Size``.

    Returns:
        A new sparse tensor whose size along dimension ``d`` is
        ``size[d] * repeat_sizes[d]`` (``size`` padded with leading 1s).

    Raises:
        RuntimeError: if fewer repeat sizes than tensor dimensions are given.
    """
    # Mirror Tensor.repeat, which also accepts one sequence of sizes.
    if len(repeat_sizes) == 1 and isinstance(repeat_sizes[0], (tuple, list)):
        repeat_sizes = tuple(repeat_sizes[0])

    orig_ndim = sparse.ndimension()
    new_ndim = len(repeat_sizes)
    if new_ndim < orig_ndim:
        raise RuntimeError(
            "Number of repeat sizes ({}) can not be smaller than the number of "
            "dimensions of the sparse tensor ({})".format(new_ndim, orig_ndim)
        )
    orig_nvalues = sparse._indices().size(1)

    # Total number of copies of the non-zero entries (product of repeat_sizes).
    total_repeats = 1
    for repeat_size in repeat_sizes:
        total_repeats *= repeat_size

    # Expand the number of dimensions to match repeat_sizes: prepend all-zero
    # index rows for the new leading size-1 dimensions.
    indices = torch.cat([
        sparse._indices().new().resize_(new_ndim - orig_ndim, orig_nvalues).zero_(),
        sparse._indices(),
    ])
    values = sparse._values()
    size = [1] * (new_ndim - orig_ndim) + list(sparse.size())

    # Allocate the output indices; every block of values is an identical copy.
    new_indices = indices.new().resize_(indices.size(0), indices.size(1) * total_repeats).zero_()
    new_values = values.repeat(total_repeats)
    new_size = [dim_size * repeat_size for dim_size, repeat_size in zip(size, repeat_sizes)]

    # With a zero repeat size the result has no non-zeros, so there is nothing
    # to fill (and copying the original indices in would not fit).
    if total_repeats > 0:
        new_indices[:, :orig_nvalues].copy_(indices)
        # Walk dimensions innermost-first. `unit_size` is the width of the block
        # of columns built so far. Each further repetition j along dim i copies
        # that block and shifts its dim-i index by j * size[i]. Column block k of
        # new_indices therefore pairs with block k of new_values.
        unit_size = orig_nvalues
        for i in reversed(range(new_ndim)):
            repeat_size = repeat_sizes[i]
            for j in range(1, repeat_size):
                block = slice(unit_size * j, unit_size * (j + 1))
                new_indices[:, block].copy_(new_indices[:, :unit_size])
                new_indices[i, block] += j * size[i]
            unit_size *= repeat_size

    # On PyTorch >= 0.4, `sparse.__class__` is the generic torch.Tensor, whose
    # constructor cannot build a sparse tensor from (indices, values, size).
    # `torch.sparse_coo_tensor` is the supported factory there. Older releases
    # lack it but use typed sparse classes that accept these arguments.
    make_sparse = getattr(torch, "sparse_coo_tensor", None)
    if make_sparse is None:
        make_sparse = sparse.__class__
    return make_sparse(new_indices, new_values, torch.Size(new_size))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号