tf_model.py source code

python

Project: TF-Net  Author: Jorba123
import tensorflow as tf


def accuracy(logits, targets_pl, one_hot=False):
    # Cast targets to int64 so they match the dtype returned by tf.argmax.
    targets = tf.cast(targets_pl, tf.int64)

    if one_hot:
        # Compare the indices of the largest entries. For a correct prediction they are the same.
        correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(targets, 1), name='accuracy_equals_oh')
    else:
        # Compare the predicted class index with the correct label, which is an integer here.
        correct_prediction = tf.equal(tf.argmax(logits, 1), targets, name='accuracy_equals')
    # Fraction of correct predictions in the batch.
    mean_accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy_mean')
    # TF1-style summary op; under TF2 it only records a value when a default summary writer is active.
    tf.summary.scalar('accuracy_mean', mean_accuracy)
    return mean_accuracy