def _sqrt_matmul(self, x, transpose_x=False):
    """Multiply `x` by `M + V D V^T`.

    Here `M` is `self._operator`, `V` is `self._v` and `D` is
    `self._diag_operator`. Judging by the method name, `M + V D V^T` is the
    square-root factor of the overall operator (NOTE(review): confirm against
    the enclosing class).

    Args:
      x: Right-hand operand. Presumably a (batch) matrix tensor whose shape is
        compatible with `M` and `V` -- TODO confirm against callers.
      transpose_x: If True, `x` is transposed before multiplication, so the
        result is `(M + V D V^T) x^T` instead of `(M + V D V^T) x`.

    Returns:
      `M x + V D V^T x` (with `x` transposed when `transpose_x` is True).
    """
    update_v = self._v
    base_op = self._operator
    diag_op = self._diag_operator

    # Dense term M x. The operator selects the matching (batch) matmul itself,
    # so the optional transpose of x is requested through its own keyword.
    base_term = base_op.matmul(x, transpose_x=transpose_x)

    # Low-rank term V D V^T x, evaluated right to left:
    # `transpose_a=True` turns V into V^T on the left, and `transpose_b`
    # applies the optional transpose of x on the right.
    vt_x = math_ops.matmul(
        update_v, x, transpose_a=True, transpose_b=transpose_x)
    low_rank_term = math_ops.matmul(update_v, diag_op.matmul(vt_x))

    return base_term + low_rank_term
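# NOTE(review): the two lines below are scraped web-page residue, not code.
# 评论列表 ("comment list")
# 文章目录 ("article table of contents")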