def _batch_sqrt_matmul(self, x, transpose_x=False):
    """Left-multiply ``x`` by the elementwise square root of ``self._diag``.

    The diagonal is reshaped into a column (a trailing axis of size 1 is
    appended) and its square root is multiplied elementwise against ``x``.
    With broadcasting, row ``i`` of ``x`` is scaled by ``sqrt(self._diag[..., i])``.
    That is the same as multiplying by ``diag(sqrt(self._diag))`` without
    materializing the dense matrix.

    Args:
      x: Tensor to multiply. Presumably shaped ``[..., k, r]`` to match a
        diagonal of shape ``[..., k]``; TODO confirm against callers.
      transpose_x: If True, ``x`` is matrix-transposed before the multiply.

    Returns:
      Tensor ``sqrt(diag_column) * x`` (after the optional transpose of ``x``).
    """
    rhs = array_ops.matrix_transpose(x) if transpose_x else x
    # Trailing unit axis lets the diagonal broadcast across the columns of rhs.
    diag_column = array_ops.expand_dims(self._diag, -1)
    root_diag = math_ops.sqrt(diag_column)
    return root_diag * rhs