def nandot(a, b):  # TODO: speed up, avoid copying data
    """A numpy.dot() replacement which treats (0*-Inf)==0 and works around BLAS NaN bugs in matrices.

    The ordinary ``np.dot`` product is computed first. Only the result entries
    that came out NaN are recomputed as an explicit sum of element-wise
    products with ``np.nansum``, which skips NaN products such as ``0 * -inf``.
    The 0 * -inf == 0 convention mirrors 0**0 == 1 for probabilities.

    Important: ``a`` is expected to hold the (possibly zero) factors and ``b``
    the values that may be ``inf``/``-inf``/``nan`` -- not the other way round.

    Parameters
    ----------
    a, b : array_like
        1-D or 2-D operands, combined with the same shape rules as ``np.dot``.
        A 1-D ``a`` acts as a single row and a 1-D ``b`` as a single column.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Same shape as ``np.dot(a, b)``: 2-D for matrix @ matrix, 1-D when
        exactly one operand is 1-D, and a scalar when both are 1-D.

    Raises
    ------
    ValueError
        If an operand is not 1-D or 2-D, or (from ``np.dot``) if the shapes
        are not aligned.

    Notes
    -----
    * Inside a recomputed entry *every* NaN product is dropped, including
      products caused by a genuine NaN in ``a`` or ``b``.
    * If a recomputed entry sums ``+inf`` and ``-inf`` (still NaN under
      ``nansum``), that entry is set to 0.0.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.ndim not in (1, 2) or b.ndim not in (1, 2):
        raise ValueError(
            "nandot() supports only 1-D and 2-D operands, got a.ndim=%d, b.ndim=%d"
            % (a.ndim, b.ndim))

    # Work on 2-D views so every result entry has a (row, column) index:
    # a 1-D left operand is one row, a 1-D right operand is one column.
    a2 = a[np.newaxis, :] if a.ndim == 1 else a
    b2 = b[:, np.newaxis] if b.ndim == 1 else b

    # 1) plain dot product, 2) locate the NaN entries.
    result = np.dot(a2, b2)
    rows, cols = np.nonzero(np.isnan(result))
    if rows.size:
        # 3) recompute only those entries: row i of a times column j of b,
        #    summed while ignoring NaN products (e.g. 0 * -inf).
        with np.errstate(invalid='ignore'):
            values = np.nansum(a2[rows, :] * b2[:, cols].T, axis=1)
        # nansum still yields NaN for +inf + -inf; force those entries to 0.
        values[np.isnan(values)] = 0.0
        result[rows, cols] = values

    # Drop the axes introduced above so the shape matches np.dot(a, b).
    if a.ndim == 1 and b.ndim == 1:
        return result[0, 0]
    if a.ndim == 1:
        return result[0]
    if b.ndim == 1:
        return result[:, 0]
    return result
评论列表
文章目录