def sqrt_log_abs_det(self):
    """Return (log o abs o det) of the matrix `M + V D V.T`.

    Despite the name, no square root is taken here; the name is kept only
    so that this method agrees with the API.

    The determinant is obtained through the matrix determinant lemma:
        det(M + V D V.T) = det(C) det(D) det(M)
    where C is defined as in `_inverse`, i.e.,
        C = inv(D) + V.T inv(M) V.
    Taking log-abs of both sides turns the product into the sum of three
    terms that is returned below.

    See: https://en.wikipedia.org/wiki/Matrix_determinant_lemma

    Returns:
      log_abs_det: `Tensor`.
    """
    # Term for C: log|det| of the Woodbury sandwiched term.
    sandwiched = self._woodbury_sandwiched_term()
    det_c = linalg_ops.matrix_determinant(sandwiched)
    log_abs_det_c = math_ops.log(math_ops.abs(det_c))

    # Term for M: sum of log|diagonal entries| of `self._m` over the last
    # axis (presumably M is triangular, so this equals log|det(M)| --
    # confirm against the class constructor).
    # Reduction is ok because we always prepad inputs to this class.
    diag_m = array_ops.matrix_diag_part(self._m)
    log_abs_diag_m = math_ops.log(math_ops.abs(diag_m))
    log_abs_det_m = math_ops.reduce_sum(
        log_abs_diag_m, reduction_indices=[-1])

    # Term for D: twice the value reported by D's own `sqrt_log_abs_det`.
    log_abs_det_d = 2. * self._d.sqrt_log_abs_det()

    return log_abs_det_c + log_abs_det_d + log_abs_det_m
bijector.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录