def log_prob_from_logits(logits, axis=-1):
    """Convert unnormalized logits to log-probabilities (log-softmax).

    Computes ``logits - logsumexp(logits)`` along ``axis``, so that
    ``exp(result)`` sums to 1 along that axis. Going through
    ``reduce_logsumexp`` instead of ``log(softmax(...))`` keeps the result
    numerically stable for large-magnitude logits.

    Args:
        logits: Tensor of unnormalized log-probabilities.
        axis: Axis holding the categories to normalize over. Defaults to the
            last axis, so each row of a batched tensor is normalized
            independently. Pass ``None`` to normalize over every element of
            the tensor jointly, which was this function's previous behavior.

    Returns:
        Tensor with the same shape as ``logits``, containing log-probabilities.
    """
    # keepdims=True keeps the reduced axis as size 1, so the subtraction
    # broadcasts back across it. (`keep_dims` is the deprecated TF1 spelling
    # and is not accepted by TF2's reduce_logsumexp.)
    return logits - tf.reduce_logsumexp(logits, axis=axis, keepdims=True)