def to_sparse(dense):
    """Convert a dense tensor into an equivalent sparse COO tensor.

    Every nonzero element of ``dense`` becomes one stored entry of the result.
    The result has the same size, dtype and device (including the CUDA device
    index) as ``dense``. An all-zero input yields a sparse tensor with no
    stored entries (nnz == 0). Such tensors are supported directly by
    ``torch.sparse_coo_tensor``, so no placeholder entry is needed.

    Args:
        dense: A strided (dense) ``torch.Tensor``.

    Returns:
        A ``torch.sparse_coo`` tensor whose ``to_dense()`` equals ``dense``.
    """
    mask = dense.ne(0)
    # nonzero() gives one row per nonzero element: shape (nnz, ndim), in
    # row-major order.
    indices = mask.nonzero()
    # Boolean-mask indexing also walks the tensor in row-major order, so
    # values[k] corresponds to indices[k].
    values = dense[mask]
    # sparse_coo_tensor expects indices as (ndim, nnz), hence the transpose.
    # Passing dense.device keeps the result on the exact same device, which a
    # bare .cuda() call would not guarantee.
    return torch.sparse_coo_tensor(
        indices.t(),
        values,
        dense.size(),
        dtype=dense.dtype,
        device=dense.device,
    )
评论列表
文章目录