def _batch_sqrt_solve(self, rhs):
    """Divide `rhs` by the elementwise square root of the stored diagonal.

    The diagonal `self._diag` gets a trailing axis of size 1. Its square root
    is then broadcast-divided into `rhs`. If `self._diag` holds the diagonal
    entries of a matrix D, this is solving sqrt(D) @ x = rhs for x.

    NOTE(review): this presumably assumes `self._diag` has shape
    batch_shape + [k] and `rhs` has shape batch_shape + [k, r]. That makes
    the trailing size-1 axis scale each row of `rhs` by one diagonal entry.
    Confirm against callers.

    Args:
      rhs: Tensor to be divided. It must broadcast against `self._diag`
        with a trailing axis of size 1 appended.

    Returns:
      `rhs / sqrt(self._diag[..., None])`.
    """
    # The trailing singleton axis lets each diagonal entry broadcast across
    # the last dimension of `rhs`.
    diag_column = array_ops.expand_dims(self._diag, -1)
    sqrt_diag_column = math_ops.sqrt(diag_column)
    return rhs / sqrt_diag_column