def _batch_sqrt_matmul(self, x, transpose_x=False):
    """Multiply `x` by `M + V D V^H`, assembled from this operator's parts.

    Here `M` is `self._operator`, `V` is `self._v` and `D` is
    `self._diag_operator`. Judging by the method name, `M + V D V^H` is
    presumably the square-root factor of the full operator -- confirm
    against the class documentation.

    Args:
      x: Tensor to multiply. Assumed to be a (batch of) matrices whose row
        count matches the row count of `self._v` (its column count when
        `transpose_x` is True) -- TODO confirm shapes with callers.
      transpose_x: If True, multiply by the transpose of `x` instead. The
        flag is forwarded unchanged to `self._operator.matmul`. In the
        low-rank term it is passed as `adjoint_b`, which also conjugates
        complex inputs. NOTE(review): adjoint and transpose coincide only
        for real dtypes.

    Returns:
      `self._operator.matmul(x, transpose_x=transpose_x)` plus
      `V @ D @ (V^H @ op(x))`, where `op(x)` is `x` or its adjoint.
    """
    # Base term M @ op(x). The operator handles the transpose flag itself.
    base_term = self._operator.matmul(x, transpose_x=transpose_x)

    # Low-rank term, evaluated right to left so the dense V D V^H is never
    # materialized: multiply by V^H, scale by D, then multiply by V.
    v_adj_x = math_ops.matmul(
        self._v, x, adjoint_a=True, adjoint_b=transpose_x)
    scaled = self._diag_operator.matmul(v_adj_x)
    low_rank_term = math_ops.matmul(self._v, scaled)

    return base_term + low_rank_term