def pseudo_inverse_sqrt2(svd, eps=1e-7):
    """Return the half (square-root) pseudo-inverse from an existing decomposition.

    Instead of decomposing a matrix itself, this reuses the factors of an
    already-computed decomposition ``svd`` (its ``s``, ``u`` and ``v``
    attributes) and returns ``u @ diag(si) @ transpose(v)``. Here ``si`` is
    ``1/sqrt(s)`` for entries of ``s`` at or above ``eps``. Entries strictly
    below ``eps`` are treated as zero and map to 0, following the
    Moore-Penrose convention.

    Args:
      svd: object whose class is named ``SvdTuple`` or ``SvdWrapper`` and
        which exposes ``s``, ``u`` and ``v`` tensors. Presumably ``s`` holds
        singular values / eigenvalues and ``u``, ``v`` the matching vectors
        as columns. TODO confirm against those classes' definitions.
      eps: zero threshold; values of ``s`` strictly below it count as zero.

    Returns:
      Tensor equal to ``u @ diag(si) @ transpose(v)``. Leading batch axes
      are also supported when ``s``, ``u`` and ``v`` share them.

    Raises:
      TypeError: if ``svd`` is not one of the supported classes. The
        original used ``assert``, which disappears under ``python -O``.
    """
    # Dispatch on the class *name* so this module need not import the classes.
    if svd.__class__.__name__ not in ('SvdTuple', 'SvdWrapper'):
        raise TypeError("Unknown type: %s" % svd.__class__.__name__)
    s, u, v = svd.s, svd.u, svd.v

    # zero threshold for eigenvalues
    is_zero = tf.less(s, eps)
    # Substitute 1 for below-threshold entries before the reciprocal sqrt.
    # tf.where still differentiates through the branch it discards, and
    # 1/sqrt(0) (or sqrt of a negative value) would make those gradients NaN.
    safe_s = tf.where(is_zero, tf.ones_like(s), s)
    si = tf.where(is_zero, tf.zeros_like(s), 1. / tf.sqrt(safe_s))

    # u @ diag(si) is u with column j scaled by si[j]. Broadcasting does this
    # without building the diagonal matrix (or using the TF1-only tf.diag).
    # transpose_b transposes only the two innermost axes of v.
    return tf.matmul(u * tf.expand_dims(si, -2), v, transpose_b=True)
# 评论列表 (web-page residue: "Comment list")
# 文章目录 (web-page residue: "Table of contents")