def _fast_dot(A, B):
    """Multiply ``A`` by ``B`` through the BLAS ``gemm`` routine.

    Every input that this function is not prepared to hand to ``gemm`` is
    rejected with ``ValueError``.  The warning text below says "Falling back
    to np.dot", so the caller is presumably expected to catch that
    ``ValueError`` and use ``np.dot`` instead -- confirm against the caller.

    Parameters
    ----------
    A, B : array-like objects exposing ``shape``, ``ndim`` and ``dtype``
        (presumably ``numpy.ndarray``).  Both must be 2-d, have no axis of
        length 1, share one dtype that is float32 or float64, and satisfy
        ``A.shape[-1] == B.shape[0]``.

    Returns
    -------
    The value returned by the ``gemm`` function selected by
    ``linalg.get_blas_funcs`` for ``(A, B)``, called with ``alpha=1.0``.

    Raises
    ------
    ValueError
        If either operand is 0-d, the inner dimensions differ, the dtypes
        differ or are not float32/float64, either operand is not 2-d, or
        either operand has an axis of length 1.

    Warns
    -----
    NonBLASDotWarning
        Emitted just before the ``ValueError`` raised for unsupported dtypes.
    """
    # A 0-d operand has an empty shape tuple, so the indexing in the next
    # check would raise IndexError.  Reject it with the same ValueError that
    # signals every other unsupported input.
    if A.ndim == 0 or B.ndim == 0:
        raise ValueError('0-d operands cannot be passed to gemm')

    if B.shape[0] != A.shape[A.ndim - 1]:  # check adopted from '_dotblas.c'
        raise ValueError('shapes {0} and {1} are not aligned'
                         .format(A.shape, B.shape))

    blas_dtypes = (np.float32, np.float64)
    if A.dtype != B.dtype or any(x.dtype not in blas_dtypes for x in (A, B)):
        warnings.warn('Falling back to np.dot. '
                      'Data must be of same type of either '
                      '32 or 64 bit float for the BLAS function, gemm, to be '
                      'used for an efficient dot operation. ',
                      NonBLASDotWarning)
        raise ValueError('operands must share a float32 or float64 dtype')

    # Only true matrix-matrix products are dispatched to gemm: vectors,
    # higher-rank arrays and operands with a length-1 axis are refused.
    if min(A.shape) == 1 or min(B.shape) == 1 or A.ndim != 2 or B.ndim != 2:
        raise ValueError('gemm path needs 2-d operands without length-1 axes')

    # scipy 0.9 compliant API
    dot = linalg.get_blas_funcs(['gemm'], (A, B))[0]
    # _impose_f_order yields the array to pass on plus the matching trans
    # flag for gemm (presumably so no copy is needed -- see that helper).
    A, trans_a = _impose_f_order(A)
    B, trans_b = _impose_f_order(B)
    return dot(alpha=1.0, a=A, b=B, trans_a=trans_a, trans_b=trans_b)
评论列表
文章目录