def _iqfov_via_sqrt_solve(self, x):
    """Compute the inverse quadratic form on vectors by way of `sqrt_solve`.

    Uses the identity

        x^T A^{-1} x = || S^{-1} x ||^2,

    where S is a square root of A (A = S S^T). The vector input is flipped
    into a matrix (extra dimensions of `x` are moved to the final dimension
    of the matrix), solved against S with `self.sqrt_solve`, flipped back
    to a vector, and then reduced to a squared L2 norm over the last axis.

    Args:
      x: Input vectors. Must expose `get_shape()`; presumably a batch
        `Tensor` compatible with this operator -- confirm against callers.

    Returns:
      The sum of squares of the solved vectors over their last dimension,
      with its static shape set to `x.get_shape()[:-1]`.
    """
    # Static shape of `x` without its trailing (vector) dimension. Needed
    # both when flipping back to a vector and as the result's static shape.
    leading_static_shape = x.get_shape()[:-1]

    # Vector -> matrix, so that sqrt_solve can be applied.
    as_matrix = flip_vector_to_matrix(
        x, self.batch_shape(), self.get_batch_shape())

    # solved_matrix = S^{-1} as_matrix.
    solved_matrix = self.sqrt_solve(as_matrix)

    # Matrix -> vector, undoing the flip above.
    solved_vector = flip_matrix_to_vector(
        solved_matrix, extract_batch_shape(x, 1), leading_static_shape)

    # Squared L2 (batch) vector norm over the last axis.
    squared_norm = math_ops.reduce_sum(
        math_ops.square(solved_vector), reduction_indices=[-1])
    squared_norm.set_shape(leading_static_shape)
    return squared_norm
评论列表
文章目录