def _sqrt_matmul(self, x, transpose_x=False):
  """Left-multiply `x` by the lower-triangular part of `self._chol`.

  Args:
    x: Tensor to multiply; must be matmul-compatible with `self._chol`
      (after the optional adjoint below). Shape/dtype are not checked here.
    transpose_x: If True, multiply by the adjoint of `x` rather than `x`
      itself, i.e. compute `L @ adjoint(x)`.

  Returns:
    `L @ x` (or `L @ adjoint(x)` when `transpose_x` is True), where `L` is
    the lower triangle of `self._chol`. Presumably `self._chol` holds a
    Cholesky factor, making `L` a matrix square root -- confirm against the
    owning class.
  """
  # num_lower=-1 keeps every sub-diagonal, num_upper=0 drops every
  # super-diagonal: only the lower triangle (including the diagonal) of
  # self._chol takes part in the product.
  chol = array_ops.matrix_band_part(self._chol, -1, 0)
  # matmul(a, b) computes a @ b; adjoint_b transposes (and, for complex
  # dtypes, conjugates) x before multiplying.
  return math_ops.matmul(chol, x, adjoint_b=transpose_x)