def sparse_repeat(sparse, *repeat_sizes):
    """Tile a sparse COO tensor along each dimension, like ``Tensor.repeat``.

    If more repeat sizes are given than ``sparse`` has dimensions, leading
    singleton dimensions are prepended to ``sparse`` first, so the trailing
    repeat sizes line up with the tensor's existing dimensions.

    Args:
        sparse: A sparse COO tensor. Its raw ``_indices()`` and ``_values()``
            are used as-is, so the input does not need to be coalesced.
        *repeat_sizes: Number of times to repeat along each dimension of the
            result. There must be at least ``sparse.ndimension()`` of them.

    Returns:
        A new sparse COO tensor whose size along every dimension is the
        (possibly padded) input size multiplied by the matching repeat size.
        The result is assembled directly from tiled indices and values and is
        not explicitly coalesced.

    Raises:
        RuntimeError: If fewer repeat sizes are given than ``sparse`` has
            dimensions.
    """
    orig_ndim = sparse.ndimension()
    new_ndim = len(repeat_sizes)
    if new_ndim < orig_ndim:
        raise RuntimeError(
            "Number of repeat sizes ({}) cannot be smaller than the number of "
            "dimensions of the sparse tensor ({})".format(new_ndim, orig_ndim)
        )

    indices = sparse._indices()
    values = sparse._values()

    # Expand the number of dimensions to match repeat_sizes: every stored
    # entry sits at index 0 of each newly prepended singleton dimension.
    num_new_dims = new_ndim - orig_ndim
    if num_new_dims > 0:
        leading = torch.zeros(
            num_new_dims, indices.size(1), dtype=indices.dtype, device=indices.device
        )
        indices = torch.cat([leading, indices])
    size = [1] * num_new_dims + list(sparse.size())

    # Tile one dimension at a time, last dimension first. Each pass lays out
    # `repeat_size` copies of everything built so far and shifts the j-th copy
    # by j * size[dim] along the current dimension.
    for dim in reversed(range(new_ndim)):
        repeat_size = repeat_sizes[dim]
        unit_size = indices.size(1)
        # `repeat` always allocates, so the input's storage is never mutated.
        indices = indices.repeat(1, repeat_size)
        values = values.repeat(repeat_size)
        offsets = torch.arange(
            repeat_size, dtype=indices.dtype, device=indices.device
        ) * size[dim]
        indices[dim] += offsets.repeat_interleave(unit_size)

    new_size = [dim_size * repeat_size for dim_size, repeat_size in zip(size, repeat_sizes)]
    return torch.sparse_coo_tensor(indices, values, torch.Size(new_size))
# (Web-page residue removed from executable code: "comment list" / "table of contents" headings.)