from chainer import cuda

# _as_batch_mat, _get_ld and _mat_ptrs are private helpers defined
# elsewhere in the same module: _as_batch_mat presents the input as a
# batch of 2-d matrices, _get_ld returns a (transpose-flag, leading-
# dimension) pair for cuBLAS, and _mat_ptrs builds the array of
# per-matrix device pointers that the batched GEMM call expects.


def _batch_matmul_gpu(a, b, out, transa=False, transb=False, transout=False):
    a = _as_batch_mat(cuda.cupy.ascontiguousarray(a))
    b = _as_batch_mat(cuda.cupy.ascontiguousarray(b))
    trans_axis = (0, 2, 1)
    if transout:
        out = out.transpose(trans_axis)
    needtrans, _ = _get_ld(out)
    if needtrans == 1:
        # cuBLAS cannot store a transposed result directly, so compute
        # the transposed product instead: (A B)^T = B^T A^T.
        a, b = b, a
        transa, transb = not transb, not transa
        out = out.transpose(trans_axis)
    if transa:
        a = a.transpose(trans_axis)
    if transb:
        b = b.transpose(trans_axis)
    transa, lda = _get_ld(a)
    transb, ldb = _get_ld(b)
    transout, ldout = _get_ld(out)
    la, n, ka = a.shape
    lb, kb, m = b.shape
    assert ka == kb
    assert transout == 0 or ldout == 1
    assert out.shape == (la, n, m)
    # Arrays of device pointers, one entry per matrix in the batch.
    ap = _mat_ptrs(a)
    bp = _mat_ptrs(b)
    outp = _mat_ptrs(out)
    # One batched single-precision GEMM over all `la` matrix pairs:
    # out[i] = 1.0 * op(a[i]) @ op(b[i]) + 0.0 * out[i].
    cuda.cublas.sgemmBatched(
        cuda.Device().cublas_handle,
        transa,
        transb,
        n, m, ka, 1.0,
        ap.data.ptr, lda,
        bp.data.ptr, ldb,
        0.0, outp.data.ptr, ldout, la)
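
A quick way to exercise the helper (not from the original source; the shapes and tolerance below are illustrative): sgemmBatched works in single precision, so the inputs must be float32, and out must be preallocated with the result shape. cupy.matmul, which batches over the leading axis, serves as a reference.

import numpy
import cupy
import cupy.testing

# Illustrative batch: 4 pairs of (3, 5) @ (5, 2) -> (3, 2) products.
a = cupy.random.randn(4, 3, 5).astype(numpy.float32)
b = cupy.random.randn(4, 5, 2).astype(numpy.float32)
out = cupy.empty((4, 3, 2), dtype=numpy.float32)

_batch_matmul_gpu(a, b, out)

# Float32 GEMM accumulates in single precision, hence the loose tolerance.
cupy.testing.assert_allclose(out, cupy.matmul(a, b), rtol=1e-4)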