operator_pd_vdvt_update.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def _batch_sqrt_matmul(self, x, transpose_x=False):
    v = self._v
    m = self._operator
    d = self._diag_operator
    # The operators call the appropriate matmul/batch_matmul automatically.
    # We cannot override.
    # batch_matmul is defined as:  x * y, so adjoint_a and adjoint_b are the
    # ways to transpose the left and right.
    mx = m.matmul(x, transpose_x=transpose_x)
    vt_x = math_ops.matmul(v, x, adjoint_a=True, adjoint_b=transpose_x)
    d_vt_x = d.matmul(vt_x)
    v_d_vt_x = math_ops.matmul(v, d_vt_x)

    return mx + v_d_vt_x
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号