operator_pd.py source file

python
Project: lsdc  Author: febert
def _iqfov_via_sqrt_solve(self, x):
    """Get the inverse quadratic form on vectors via a sqrt_solve."""
    # x^T A^{-1} x = || S^{-1} x ||^2,
    # where S is a square root of A (A = S S^T).
    # Steps:
    # 1. Convert x to a matrix, flipping all extra dimensions in `x` to the
    #    final dimension of x_matrix.
    x_matrix = flip_vector_to_matrix(
        x, self.batch_shape(), self.get_batch_shape())
    # 2. Get soln_matrix = S^{-1} x_matrix
    soln_matrix = self.sqrt_solve(x_matrix)
    # 3. Reshape back to a vector.
    soln = flip_matrix_to_vector(
        soln_matrix, extract_batch_shape(x, 1), x.get_shape()[:-1])
    # 4. L2 (batch) vector norm squared.
    result = math_ops.reduce_sum(
        math_ops.square(soln), reduction_indices=[-1])
    result.set_shape(x.get_shape()[:-1])
    return result
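
The identity in the comments, x^T A^{-1} x = || S^{-1} x ||^2 with A = S S^T, can be checked on a single vector without any batching. The sketch below is not part of operator_pd.py: it assumes S is the lower-triangular Cholesky factor of a small symmetric positive definite matrix, and the helper names `cholesky`, `forward_solve` and `inv_quadratic_form` are hypothetical, written in plain Python for illustration only.

```python
import math


def cholesky(a):
    """Return lower-triangular S with a = S S^T (a must be symmetric positive definite)."""
    n = len(a)
    s = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            acc = sum(s[i][k] * s[j][k] for k in range(j))
            if i == j:
                s[i][j] = math.sqrt(a[i][i] - acc)
            else:
                s[i][j] = (a[i][j] - acc) / s[j][j]
    return s


def forward_solve(s, x):
    """Solve S y = x for lower-triangular S by forward substitution."""
    y = []
    for i, row in enumerate(s):
        acc = sum(row[k] * y[k] for k in range(i))
        y.append((x[i] - acc) / row[i])
    return y


def inv_quadratic_form(a, x):
    """Compute x^T A^{-1} x as the squared L2 norm of S^{-1} x."""
    y = forward_solve(cholesky(a), x)  # y = S^{-1} x
    return sum(v * v for v in y)       # ||y||^2


# Illustrative values (assumed, not from the source).
A = [[4.0, 2.0], [2.0, 3.0]]
x = [2.0, 1.0]
value = inv_quadratic_form(A, x)
```

Solving against the triangular factor avoids forming A^{-1} explicitly, which is the same motivation for routing the batch version through `sqrt_solve`.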