extmath.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def _fast_dot(A, B):
    """Compute the matrix product of ``A`` and ``B`` with the BLAS ``gemm`` routine.

    This function either computes the product via ``gemm`` or rejects the
    inputs with ``ValueError``; it never computes the product another way.
    The warning text below ("Falling back to np.dot") implies the caller
    catches ``ValueError`` and retries with ``np.dot`` -- NOTE(review):
    confirm against the caller.

    Parameters
    ----------
    A, B : ndarray
        Operands. The gemm path is taken only when:

        - both are 2-D;
        - both share a dtype of float32 or float64;
        - ``A``'s last dimension equals ``B``'s first dimension;
        - neither has an axis of length 0 or 1.

    Returns
    -------
    ndarray
        The product ``A @ B`` as returned by the dtype-specific scipy BLAS
        ``gemm`` wrapper (called with ``alpha=1.0``).

    Raises
    ------
    ValueError
        If the inputs are not eligible for the gemm path (see above). When
        the rejection is due to dtype, a ``NonBLASDotWarning`` is emitted
        before raising.
    """
    # 0-d arrays have no axes to index. Without this guard, the shape check
    # below would raise IndexError, which a ValueError-based fallback to
    # np.dot would not catch.
    if A.ndim == 0 or B.ndim == 0:
        raise ValueError("_fast_dot does not support 0-d (scalar) arrays")

    # Inner dimensions must agree (check adopted from '_dotblas.c').
    if B.shape[0] != A.shape[-1]:
        raise ValueError("shapes %s and %s not aligned" % (A.shape, B.shape))

    # gemm needs both operands to share one float32 or float64 dtype.
    if A.dtype != B.dtype or any(x.dtype not in (np.float32, np.float64)
                                 for x in (A, B)):
        warnings.warn('Falling back to np.dot. '
                      'Data must be of same type of either '
                      '32 or 64 bit float for the BLAS function, gemm, to be '
                      'used for an efficient dot operation. ',
                      NonBLASDotWarning)
        raise ValueError("gemm requires both arrays to share a float32 or "
                         "float64 dtype")

    # gemm only multiplies 2-D matrices. Degenerate axes (length 0 or 1)
    # gain nothing from gemm, so leave them to the fallback as well.
    if A.ndim != 2 or B.ndim != 2 or min(A.shape) <= 1 or min(B.shape) <= 1:
        raise ValueError("gemm path requires 2-D arrays with every "
                         "dimension greater than 1")

    # scipy 0.9 compliant API: get_blas_funcs takes a list of names and
    # returns a list of functions matching the arrays' dtype.
    dot = linalg.get_blas_funcs(['gemm'], (A, B))[0]
    # NOTE(review): _impose_f_order is defined elsewhere. It presumably
    # returns a Fortran-ordered array plus a flag telling gemm whether to
    # treat it as transposed -- verify against its definition.
    A, trans_a = _impose_f_order(A)
    B, trans_b = _impose_f_order(B)
    return dot(alpha=1.0, a=A, b=B, trans_a=trans_a, trans_b=trans_b)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号