def _batch_matmul(self, x, transpose_x=False):
    """Multiply ``x`` on the left by ``L @ L^H``, with ``L`` taken from ``self._chol``.

    Only the lower triangle of ``self._chol`` is used: ``matrix_band_part``
    zeroes everything above the main diagonal, and the result is the factor
    ``L`` used below. The product is formed as ``L @ (L^H @ op(x))`` with two
    matmuls, so ``L @ L^H`` itself is never materialized.

    Args:
      x: Tensor placed on the right-hand side of the product. Presumably it
        has batch/matrix dimensions compatible with ``self._chol`` (TODO:
        confirm against callers).
      transpose_x: If True, ``x`` is replaced by its adjoint before
        multiplying. For real dtypes this is the plain transpose; for complex
        dtypes it is the conjugate transpose.

    Returns:
      The tensor ``L @ L^H @ op(x)``, where ``op`` is either the identity or
      the adjoint, depending on ``transpose_x``.
    """
    # num_lower=-1 keeps every sub-diagonal and num_upper=0 drops every
    # super-diagonal, leaving the lower-triangular factor L.
    lower_factor = array_ops.matrix_band_part(self._chol, -1, 0)
    # math_ops.matmul(a, b) computes a @ b, so ``x`` goes on the right.
    # Apply the right-most factor first: L^H @ op(x).
    inner = math_ops.matmul(
        lower_factor, x, adjoint_a=True, adjoint_b=transpose_x)
    # Then multiply by L on the left to finish L @ L^H @ op(x).
    result = math_ops.matmul(lower_factor, inner)
    return result
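# Stray navigation labels left over from the web page this code was copied
# from (not code): 评论列表 ("Comment list"), 文章目录 ("Table of contents").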