common.py source code

python

Project: mglex  Author: fungs
import numpy as np


def nandot(a, b):  # TODO: speed up, avoid copying data
    "A numpy.dot() replacement which treats (0*-Inf)==0 and works around BLAS NaN bugs in matrices."
    # Important: a is expected to contain the zeros and b the inf/-inf/nan values, not the other way around.
    # Both a and b must be 2-D arrays (matrices); the row/column unpacking below relies on that.

    # Workaround for 0*-inf=nan in the dot product (it must be 0, following 0^0=1 for probabilities):
    # 1) calculate the dot product
    # 2) select the NaN entries
    # 3) recompute only those entries with np.nansum(), so that 0*inf terms count as 0
    tmp = np.dot(a, b)
    indices = np.where(np.isnan(tmp))
    ri, ci = indices
    with np.errstate(invalid='ignore'):
        # element-wise products of row ri of a and column ci of b; nansum skips the NaN terms
        values = np.nansum(a[ri, :] * b[:, ci].T, axis=1)
    # a remaining NaN means inf + (-inf) within one entry; set it to 0 as well
    values[np.isnan(values)] = 0.0
    tmp[indices] = values
    return tmp
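To make the intended semantics concrete, here is a small brute-force reference, a sketch only: the name `nandot_reference` is hypothetical and not part of mglex. It computes every term a[i, k] * b[k, j] explicitly and forces terms with a[i, k] == 0 to exactly 0. It assumes, as in the note inside nandot(), that a holds the zeros and b holds the -inf values; unlike nandot(), it does not drop NaN terms that come from b itself. It needs O(n*k*m) memory, so it is only suitable for checking nandot() on small matrices.

```python
import numpy as np


def nandot_reference(a, b):
    """Hypothetical brute-force reference for the 0 * -inf == 0 rule (not part of mglex).

    out[i, j] = sum_k a[i, k] * b[k, j], where any term with a[i, k] == 0
    contributes exactly 0, even if b[k, j] is inf or -inf.
    Memory use is O(n * k * m), so use it only for small test matrices.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # terms[i, k, j] = a[i, k] * b[k, j] via broadcasting; 0 * inf yields nan here
    with np.errstate(invalid="ignore"):
        terms = a[:, :, None] * b[None, :, :]
    # enforce the convention: a zero factor in a makes the whole term 0
    terms = np.where(a[:, :, None] == 0, 0.0, terms)
    # sum over the shared dimension k
    return terms.sum(axis=1)


if __name__ == "__main__":
    a = np.array([[0.0, 1.0], [2.0, 3.0]])        # e.g. counts, containing a zero
    b = np.array([[-np.inf, 0.5], [-1.0, -2.0]])  # e.g. log-probabilities, containing -inf
    with np.errstate(invalid="ignore"):
        print("np.dot:   ", np.dot(a, b))
    print("reference:", nandot_reference(a, b))  # values: [[-1, -2], [-inf, -5]]
```

With these inputs, a plain np.dot() typically returns nan at position [0, 0], because 0 * -inf is nan under IEEE 754, while the reference (and nandot()) gives -1 there; a nonzero factor times -inf still correctly yields -inf at [1, 0].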