def accuracy(logits, targets_pl, one_hot=False):
    """Build the graph ops that compute mean classification accuracy.

    The predicted class is the arg-max of ``logits`` along axis 1. It is
    compared element-wise with the target class, and the mean of the
    matches is returned. A scalar summary tagged ``accuracy_mean`` is
    registered as a side effect.

    Args:
        logits: Model outputs; arg-max is taken along axis 1
            (presumably shaped ``(batch, num_classes)`` -- confirm with caller).
        targets_pl: Target labels. When ``one_hot`` is False these are class
            indices; when True they are per-class scores along axis 1.
        one_hot: Whether ``targets_pl`` is one-hot encoded.

    Returns:
        A scalar float32 tensor holding the fraction of correct predictions.
    """
    predicted_class = tf.argmax(logits, 1)
    if one_hot:
        # Reduce the one-hot targets to class indices. Cast to float rather
        # than int64 so fractional (e.g. smoothed) targets are not truncated
        # to all zeros before the arg-max; bool/int inputs still work.
        target_class = tf.argmax(tf.cast(targets_pl, tf.float32), 1)
        correct_prediction = tf.equal(
            predicted_class, target_class, name='accuracy_equals_oh')
    else:
        # Targets already are class indices; cast to int64 so the dtype
        # matches the default output dtype of tf.argmax for tf.equal.
        target_class = tf.cast(targets_pl, tf.int64)
        correct_prediction = tf.equal(
            predicted_class, target_class, name='accuracy_equals')
    mean_accuracy = tf.reduce_mean(
        tf.cast(correct_prediction, tf.float32), name='accuracy_mean')
    tf.summary.scalar('accuracy_mean', mean_accuracy)
    return mean_accuracy
评论列表
文章目录