def bdsmm(sparse, dense):
    """
    Batch dense-sparse matrix multiply.

    For a 3-D sparse input, the batch is folded into a single block-diagonal
    2-D sparse matrix. One sparse @ dense product then handles every batch
    entry at once, and the result is reshaped back into batch form.

    Args:
        sparse: sparse COO tensor, either 2-D ``(n_rows, n_cols)`` or 3-D
            ``(batch_size, n_rows, n_cols)``. In the 3-D case all three
            dimensions must be sparse (no hybrid dense dims). Row 0 of its
            indices is read as the batch index and rows 1-2 as row/column
            indices.
        dense: dense tensor. With 2-D ``sparse`` it is passed straight to the
            matrix product. With 3-D ``sparse`` it may be
            ``(batch_size, n_cols, k)``, ``(1, n_cols, k)`` or ``(n_cols, k)``.
            The last two are broadcast across the batch.

    Returns:
        A dense tensor. For 3-D ``sparse`` its shape is
        ``(batch_size, n_rows, k)``. Otherwise it is the 2-D product
        ``sparse @ dense``.

    Raises:
        ValueError: if ``sparse`` has more than 3 dimensions.
        RuntimeError: if ``dense`` cannot be broadcast to
            ``(batch_size, n_cols, k)``.
    """
    if sparse.ndimension() <= 2:
        return torch.sparse.mm(sparse, dense)

    if sparse.ndimension() != 3:
        raise ValueError(
            "bdsmm expects a 2-D or 3-D sparse tensor, got a "
            f"{sparse.ndimension()}-D tensor"
        )

    batch_size, n_rows, n_cols = sparse.size()

    # Shift each entry's row/column index by (batch index * block size), so
    # that batch b lands in diagonal block b of a
    # (batch_size * n_rows, batch_size * n_cols) matrix.
    # Only the row/column index rows are cloned; the batch row is only read.
    batch_assignment = sparse._indices()[0]
    indices = sparse._indices()[1:].clone()
    indices[0].add_(batch_assignment, alpha=n_rows)
    indices[1].add_(batch_assignment, alpha=n_cols)
    sparse_2d = torch.sparse_coo_tensor(
        indices,
        sparse._values(),
        torch.Size((batch_size * n_rows, batch_size * n_cols)),
    )

    # Broadcast a 2-D or single-batch dense operand across the batch.
    # expand() raises RuntimeError for any other batch size instead of
    # letting the reshape below silently reinterpret the data.
    if dense.ndimension() == 2:
        dense = dense.unsqueeze(0)
    dense = dense.expand(batch_size, -1, -1)
    if dense.size(1) != n_cols:
        raise RuntimeError(
            f"bdsmm: sparse has {n_cols} columns but dense has "
            f"{dense.size(1)} rows"
        )

    # Stack the per-batch dense blocks vertically so that block b lines up
    # with diagonal block b of sparse_2d.
    dense_2d = dense.reshape(batch_size * n_cols, -1)

    res = torch.sparse.mm(sparse_2d, dense_2d)
    return res.view(batch_size, n_rows, -1)
评论列表
文章目录