matmul.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:chainer-deconv 作者: germanRos 项目源码 文件源码
def _batch_matmul_gpu(a, b, out, transa=False, transb=False, transout=False):
    """Compute a batched matrix product on the GPU and store it in ``out``.

    The product is evaluated by one ``cuda.cublas.sgemmBatched`` call with
    alpha = 1.0 and beta = 0.0, so whatever ``out`` held before is
    overwritten rather than accumulated into.

    Args:
        a: Left operand. It is made contiguous and passed through
            ``_as_batch_mat``; the result must unpack to a 3-tuple shape.
        b: Right operand, prepared the same way as ``a``.
        out: Destination array, written in place.
        transa: If true, every matrix in ``a`` is transposed (axes 1 and 2
            swapped) before multiplying.
        transb: Same as ``transa`` but for ``b``.
        transout: If true, the result is written to ``out`` transposed.

    Raises:
        AssertionError: If the operands' batch sizes or inner dimensions
            disagree, or if ``out`` does not have the expected shape/layout.

    Returns:
        None. The result is delivered through ``out``.
    """
    a = _as_batch_mat(cuda.cupy.ascontiguousarray(a))
    b = _as_batch_mat(cuda.cupy.ascontiguousarray(b))
    trans_axis = (0, 2, 1)
    if transout:
        out = out.transpose(trans_axis)
    # NOTE(review): _get_ld presumably returns (transpose flag, leading
    # dimension) for an array's memory layout -- confirm against its
    # definition.
    needtrans, _ = _get_ld(out)
    if needtrans == 1:
        # out cannot be written as-is, so compute the transposed product
        # instead: (A B)^T = B^T A^T
        a, b = b, a
        transa, transb = not transb, not transa
        out = out.transpose(trans_axis)
    if transa:
        a = a.transpose(trans_axis)
    if transb:
        b = b.transpose(trans_axis)

    # From here on transa/transb/transout are no longer the boolean
    # arguments: they are rebound to the flags reported by _get_ld for the
    # final views, which is what gets handed to cuBLAS.
    transa, lda = _get_ld(a)
    transb, ldb = _get_ld(b)
    transout, ldout = _get_ld(out)
    la, n, ka = a.shape
    lb, kb, m = b.shape

    # The batch count passed to cuBLAS is ``la``; a shorter ``b`` would make
    # the kernel read past the end of b's pointer array, so reject it here.
    assert la == lb
    assert ka == kb
    assert transout == 0 or ldout == 1
    assert out.shape == (la, n, m)

    ap = _mat_ptrs(a)
    bp = _mat_ptrs(b)
    outp = _mat_ptrs(out)
    cuda.cublas.sgemmBatched(
        cuda.Device().cublas_handle,
        transa,
        transb,
        n, m, ka, 1.0,
        ap.data.ptr, lda,
        bp.data.ptr, ldb,
        0.0, outp.data.ptr, ldout, la)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号