def log_prob_from_logits(logits, axis=2):
    """Normalize logits into log-probabilities along one axis.

    Subtracts the log-sum-exp taken over ``axis`` from ``logits``, which is
    the log-softmax of ``logits`` along that axis: exponentiating the result
    and summing over ``axis`` yields 1.

    Args:
        logits: Tensor of unnormalized log-probabilities. With the default
            ``axis=2`` it must have rank >= 3 (presumably
            ``[batch, time, classes]`` -- confirm against callers).
        axis: Axis to normalize over. Defaults to 2, the value that was
            previously hard-coded.

    Returns:
        A tensor with the same shape as ``logits``.
    """
    # `keepdims` replaces the deprecated `keep_dims` spelling, which the
    # TF 2.x signature rejects. The reduced axis is kept with size 1 so the
    # subtraction broadcasts against `logits`.
    log_normalizer = tf.reduce_logsumexp(logits, axis=axis, keepdims=True)
    return logits - log_normalizer