def sqrt_solve(self, x):
    """Apply the inverse of this operator to `x`, i.e. `solve(self, x)`.

    No square root is taken here; the name only exists to match the API.

    Writing the operator as (M + V D V.T), the product inv(M + V D V.T) x
    is obtained from the Woodbury matrix identity:
      inv(M + V D V.T) = inv(M) - inv(M) V inv(C) V.T inv(M)
    with
      C = inv(D) + V.T inv(M) V.
    See: https://en.wikipedia.org/wiki/Woodbury_matrix_identity

    Args:
      x: `Tensor`

    Returns:
      inv_of_self_times_x: `Tensor`
    """
    # inv(M) x, via a triangular solve against `self._m`.
    m_solved = linalg_ops.matrix_triangular_solve(self._m, x)

    # V.T inv(M) x
    projected = math_ops.matmul(self._v, m_solved, transpose_a=True)

    # inv(C) V.T inv(M) x. The matrix solved against comes from
    # `_woodbury_sandwiched_term()` (presumably C from the identity above;
    # confirm against that helper).
    sandwiched = self._woodbury_sandwiched_term()
    core = linalg_ops.matrix_solve(sandwiched, projected)

    # inv(M) V inv(C) V.T inv(M) x
    lifted = math_ops.matmul(self._v, core)
    correction = linalg_ops.matrix_triangular_solve(self._m, lifted)

    return m_solved - correction
# Source file: bijector.py (Python)