def _chollogdet(L):
    """Compute ``log det(L @ L^T)`` from a Cholesky factor ``L``.

    Uses the identity ``log det(L L^T) = 2 * sum(log(diag(L)))``, which
    holds because the determinant of a triangular matrix is the product of
    its diagonal entries.

    Parameters
    ----------
    L : Tensor
        Cholesky factor(s) of shape ``(..., D, D)``.

    Returns
    -------
    Tensor
        A scalar. ``tf.reduce_sum`` is called without an ``axis``, so when
        ``L`` has leading batch dimensions the result is the *sum* of the
        per-matrix log-determinants, not one value per matrix.
    """
    diagonal = tf.matrix_diag_part(L)
    # `pos` is intended to keep the diagonal > 0 (so the log is finite)
    # without a vanishing gradient; its exact transform is defined elsewhere.
    log_diagonal = tf.log(pos(diagonal))
    return 2. * tf.reduce_sum(log_diagonal)