def symsqrt(mat, eps=1e-7):
    """Symmetric square root, computed through an SVD of ``mat``.

    The matrix is factored with ``tf.svd`` and reassembled as
    ``u @ diag(si) @ transpose(v)``, where ``si`` holds the square root of
    every singular value that is not below ``eps``.

    Args:
        mat: Matrix to take the root of. Presumably symmetric positive
            semi-definite (going by the function name) -- nothing in this
            function checks that; confirm against callers.
        eps: Threshold; singular values strictly below it are not
            square-rooted.

    Returns:
        The tensor ``u @ diag(si) @ transpose(v)``.
    """
    singular_values, left, right = tf.svd(mat)

    # Printed each time this function is called, regardless of whether any
    # singular value actually falls below eps.
    print("Warning, cutting off at eps")

    # sqrt is unstable around 0, so values below eps skip it. Note that they
    # are passed through unchanged (the singular value itself), not set to 0.
    below_eps = tf.less(singular_values, eps)
    rooted = tf.sqrt(singular_values)
    adjusted = tf.where(below_eps, singular_values, rooted)

    return left @ tf.diag(adjusted) @ tf.transpose(right)
评论列表
文章目录