def _iqfov_via_solve(self, x):
    """Compute the inverse quadratic form ``x^T A^{-1} x`` via ``self.solve``.

    ``A^{-1}`` is never formed explicitly. Instead the vectors in ``x`` are
    packed into a matrix, solved against the operator in a single call, then
    unpacked back into vector form and contracted with ``x`` along the last
    axis.

    Args:
      x: Tensor of vectors. The last dimension is contracted; the leading
        dimensions are treated as batch dimensions (assumed to be compatible
        with the operator's batch shape -- TODO confirm against callers).

    Returns:
      Tensor whose static shape is set to ``x``'s static shape without its
      last dimension, holding ``x^T A^{-1} x`` for each vector.
    """
    # Pack x into a matrix. Per the helper's contract, dimensions of x beyond
    # the operator's batch shape are moved into the final dimension of the
    # packed matrix.
    op_batch_dynamic = self.batch_shape()
    op_batch_static = self.get_batch_shape()
    packed_x = flip_vector_to_matrix(x, op_batch_dynamic, op_batch_static)

    # One solve for every packed column: packed_y = A^{-1} packed_x.
    packed_y = self.solve(packed_x)

    # Unpack so that y lines up element-for-element with x.
    x_batch_dynamic = extract_batch_shape(x, 1)
    x_batch_static = x.get_shape()[:-1]
    y = flip_matrix_to_vector(packed_y, x_batch_dynamic, x_batch_static)

    # x^T A^{-1} x == sum over the last axis of x * (A^{-1} x).
    quad_form = math_ops.reduce_sum(x * y, reduction_indices=[-1])
    quad_form.set_shape(x_batch_static)
    return quad_form
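# Comment list (评论列表) -- leftover web-page label, not part of the code.
# Table of contents (文章目录) -- leftover web-page label, not part of the code.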