def _sample(self, n_samples):
    """Draw ``n_samples`` samples by normalising independent Gamma draws.

    Independent ``Gamma(alpha_i, 1)`` variates normalised to sum to one
    along the last axis are distributed as ``Dirichlet(alpha)``.

    Args:
        n_samples: Number of samples to draw. It is prepended to the shape
            of ``self.alpha``.

    Returns:
        A tensor of dtype ``self.dtype`` with shape
        ``[n_samples] + shape(self.alpha)``. Each slice along the last axis
        sums to one.
    """
    # tf.random_gamma returns shape `[n_samples] + shape(alpha)`, drawn with
    # unit rate (beta=1).
    gamma_draws = tf.random_gamma([n_samples], self.alpha,
                                  beta=1, dtype=self.dtype)
    # `keepdims` replaces the deprecated `keep_dims` alias (removed in TF 2).
    # Keeping the reduced axis lets the sum broadcast back over the last axis.
    # NOTE(review): for very small alpha, float32 Gamma draws can underflow
    # to exactly 0. The ratio then becomes 0/0 = NaN. Consider log-space
    # sampling if such concentrations are expected.
    normalizer = tf.reduce_sum(gamma_draws, -1, keepdims=True)
    return gamma_draws / normalizer