def log_dirichlet(self, size, scale=1.0): mu = tf.random_gamma([1], scale * np.ones(size).astype(np.float32)) mu = tf.log(mu / tf.reduce_sum(mu)) return mu