dgmm.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:pyinn 作者: szagoruyko 项目源码 文件源码
def cublas_dgmm(A, x, out=None):
    """Scale ``A`` by the diagonal matrix ``diag(x)`` and return the result.

    If ``x.numel() == A.size(-1)`` the entries along the last dimension of
    ``A`` are scaled (``A @ diag(x)`` for a 2-D ``A``).  Otherwise
    ``x.numel()`` must equal ``A.size(0)`` and the slices along the first
    dimension are scaled (``diag(x) @ A`` for a 2-D ``A``).  When both
    lengths match, the last-dimension interpretation wins.

    CUDA float/double tensors are dispatched to cuBLAS ``?dgmm`` through
    ``skcuda``; every other tensor type uses a broadcasting element-wise
    multiply.

    Args:
        A: contiguous tensor to be scaled.
        x: 1-D tensor of scale factors with the same tensor type as ``A``.
        out: optional contiguous tensor of the same size and type as ``A``
            that receives the result; allocated when omitted.

    Returns:
        ``out``, holding the scaled values.

    Raises:
        AssertionError: if any of the shape / type / contiguity
            requirements above is violated.
    """
    if out is not None:
        assert out.is_contiguous() and out.size() == A.size()
    else:
        out = A.new(A.size())
    assert x.dim() == 1
    assert x.numel() == A.size(-1) or x.numel() == A.size(0)
    assert A.type() == x.type() == out.type()
    assert A.is_contiguous()

    # Last-dimension scaling takes precedence when both lengths match.
    scale_last = x.numel() == A.size(-1)

    if not isinstance(A, (torch.cuda.FloatTensor, torch.cuda.DoubleTensor)):
        # Multiplying by diag(x) is just an element-wise scaling, so use a
        # broadcasting multiply instead of materialising the diagonal
        # matrix.  This also works for N-D ``A`` and avoids passing ``out``
        # to ``Tensor.mm``, which does not accept that keyword.
        if scale_last:
            scale = x
        else:
            scale = x.view(-1, *([1] * (A.dim() - 1)))
        return torch.mul(A, scale, out=out)

    if A.numel() == 0:
        # Nothing to compute; also avoids dividing by a zero-sized dimension.
        return out

    # cuBLAS is column-major, so the contiguous row-major buffer of ``A`` is
    # seen as its transpose: an (m x n) matrix whose leading dimension m is
    # the fastest-varying extent of the flattened tensor.
    if scale_last:
        m = A.size(-1)
        n = A.numel() // m
        mode = 'l'  # diag(x) applied from the left of the column-major view
    else:
        n = A.size(0)
        m = A.numel() // n
        mode = 'r'  # diag(x) applied from the right of the column-major view
    lda, ldc = m, m

    # The call below passes incx == 1, which is only valid for a densely
    # packed ``x``; this is a no-op when ``x`` is already contiguous.
    x = x.contiguous()
    incx = 1

    handle = torch.cuda.current_blas_handle()
    stream = torch.cuda.current_stream()._as_parameter_
    from skcuda import cublas
    cublas.cublasSetStream(handle, stream)
    args = [handle, mode, m, n, A.data_ptr(), lda,
            x.data_ptr(), incx, out.data_ptr(), ldc]
    if isinstance(A, torch.cuda.FloatTensor):
        cublas.cublasSdgmm(*args)
    else:
        cublas.cublasDdgmm(*args)
    return out
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号