def tower_acc(logit, labels):
    """Return the mean rate at which the argmax of ``logit`` matches ``labels``.

    The index of the largest value along axis 1 of ``logit`` is compared
    elementwise with ``labels``; the boolean matches are cast to float32
    and averaged into a single value.

    Args:
        logit: Scores to take the argmax of along axis 1.
            NOTE(review): presumably shaped (batch, num_classes) -- confirm
            against the caller.
        labels: Reference values compared against the argmax indices.
            NOTE(review): presumably integer class ids whose dtype matches
            the output of ``tf.argmax`` -- confirm against the caller.

    Returns:
        The float32 mean of the elementwise match indicator.
    """
    predicted = tf.argmax(logit, 1)
    # Cast the boolean matches to 1.0 / 0.0 so their mean is the hit rate.
    hits = tf.cast(tf.equal(predicted, labels), tf.float32)
    return tf.reduce_mean(hits)