__init__.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def bdsmm(sparse, dense):
    """
    Batch sparse-dense matrix multiply (computes ``sparse @ dense``).

    A batched (3-D) sparse matrix is flattened into one 2-D block-diagonal
    sparse matrix so the whole batch can be multiplied with a single 2-D
    sparse-dense product.

    Args:
        sparse: Sparse COO tensor, either 2-D ``(n_rows, n_cols)`` or
            3-D ``(batch_size, n_rows, n_cols)``.
        dense: Dense tensor. For a 3-D ``sparse`` it is viewed as
            ``(batch_size * n_cols, -1)``; a leading dimension of size 1 is
            repeated ``batch_size`` times first. For a 2-D ``sparse`` it is
            passed straight to ``torch.dsmm``.

    Returns:
        Dense tensor of shape ``(batch_size, n_rows, -1)`` for a 3-D
        ``sparse``, otherwise the result of ``torch.dsmm(sparse, dense)``.
    """
    if sparse.ndimension() <= 2:
        return torch.dsmm(sparse, dense)

    batch_size, n_rows, n_cols = sparse.size()

    # Row 0 of the index matrix holds the batch each non-zero belongs to;
    # the remaining rows are its (row, col) position within that batch.
    all_indices = sparse._indices()
    batch_assignment = all_indices[0]
    indices = all_indices[1:].clone()

    # Shift every entry onto its own diagonal block: batch b occupies rows
    # [b * n_rows, (b + 1) * n_rows) and cols [b * n_cols, (b + 1) * n_cols).
    # The keyword ``alpha`` form replaces the removed ``add_(scalar, tensor)``
    # overload.
    indices[0].add_(batch_assignment, alpha=n_rows)
    indices[1].add_(batch_assignment, alpha=n_cols)

    # ``torch.sparse_coo_tensor`` replaces ``sparse.__class__(...)``, which
    # relied on the legacy per-type sparse tensor classes. dtype and device
    # are inferred from the values tensor.
    sparse_2d = torch.sparse_coo_tensor(
        indices,
        sparse._values(),
        torch.Size((batch_size * n_rows, batch_size * n_cols)),
    )

    # Share a single dense matrix across the whole batch.
    if dense.size(0) == 1:
        dense = dense.repeat(batch_size, 1, 1)

    # Stack the per-batch dense matrices vertically so they line up with the
    # column blocks of the block-diagonal sparse matrix.
    dense_2d = dense.contiguous().view(batch_size * n_cols, -1)
    res = torch.dsmm(sparse_2d, dense_2d)
    return res.view(batch_size, n_rows, -1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号