def categorical_max(logits, d):
    """Return a one-hot encoding of the argmax of ``logits`` along axis 1.

    Args:
        logits: Tensor of scores. The argmax is taken over axis 1; presumably
            shape is (batch, d) — TODO confirm against callers.
        d: Depth of the one-hot encoding, passed to ``tf.one_hot``.

    Returns:
        ``tf.one_hot(argmax(logits, axis=1), d)`` (``tf.one_hot``'s default
        dtype, float32).
    """
    # The argmax is taken directly on the raw logits. The previous version
    # subtracted the row max first. That is a no-op for argmax, but it turned
    # +inf (or all -inf) rows into NaN (inf - inf) and could round away small
    # differences. It also relied on the deprecated `keep_dims=` keyword,
    # which was removed in TF 2.
    index = tf.argmax(logits, axis=1)
    return tf.one_hot(index, d)