def categorical_sample_logits(X):
    """Sample category indices from unnormalized logits via the Gumbel-max trick.

    Uniform noise with the same shape as ``X`` is drawn, transformed into
    Gumbel noise (``-log(-log(u))``), added to ``X``, and the index of the
    largest perturbed value along axis 1 is returned.

    Args:
        X: Tensor of logits. Assumed to be 2-D with categories along axis 1
            (e.g. ``(batch, num_categories)``) -- TODO confirm against callers.

    Returns:
        The result of ``argmax(..., axis=1)`` on the perturbed logits;
        presumably a tensor of sampled category indices, one per row of ``X``
        (``argmax`` is a helper defined elsewhere in this module -- verify).
    """
    # Background: https://github.com/tensorflow/tensorflow/issues/456
    uniform_noise = tf.random_uniform(tf.shape(X))
    # Subtracting log(-log(u)) is the same as adding Gumbel noise -log(-log(u)).
    perturbed_logits = X - tf.log(-tf.log(uniform_noise))
    return argmax(perturbed_logits, axis=1)