_util.py source code

python

Project: pydatalab    Author: googledatalab
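import tensorflow as tf  # TensorFlow 1.x API (tf.assign_add, graph-mode variables)


def accuracy(logits, labels):
  """Calculates top-1 accuracy aggregated across batches.

  Returns a list of update ops to run once per batch, and an op that reads
  the running accuracy from the accumulated counts.
  """
  # Per-example boolean: is the true label the top-1 prediction?
  is_correct = tf.nn.in_top_k(logits, labels, 1)
  correct = tf.reduce_sum(tf.cast(is_correct, tf.int32))
  incorrect = tf.reduce_sum(tf.cast(tf.logical_not(is_correct), tf.int32))
  # Non-trainable running totals that persist across session.run calls.
  correct_count = tf.Variable(0, trainable=False)
  incorrect_count = tf.Variable(0, trainable=False)
  correct_count_update = tf.assign_add(correct_count, correct)
  incorrect_count_update = tf.assign_add(incorrect_count, incorrect)
  # Reads the current totals; it does not run the updates itself.
  accuracy_op = tf.cast(correct_count, tf.float32) / tf.cast(
      correct_count + incorrect_count, tf.float32)
  return [correct_count_update, incorrect_count_update], accuracy_op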
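The function above keeps two running counters that are bumped once per batch, and reports accuracy as correct / (correct + incorrect) over everything seen so far. The following is a minimal pure-Python sketch of the same bookkeeping, written without TensorFlow so the logic can be checked directly. It assumes logits are lists of floats and labels are integer class indices. The class name `StreamingAccuracy` and its methods are hypothetical illustrations, not part of pydatalab. Ties count as correct, which follows the documented tie handling of `tf.nn.in_top_k`.

```python
import math


class StreamingAccuracy:
  """Hypothetical pure-Python analogue of accuracy(): running top-1 counters."""

  def __init__(self):
    self.correct_count = 0
    self.incorrect_count = 0

  def update(self, logits, labels):
    """Accumulates one batch; plays the role of running the two assign_add ops."""
    for row, label in zip(logits, labels):
      # Top-1 hit if no class scores strictly higher than the true label,
      # so ties count as correct.
      if row[label] >= max(row):
        self.correct_count += 1
      else:
        self.incorrect_count += 1

  def result(self):
    """Plays the role of accuracy_op: correct / (correct + incorrect)."""
    total = self.correct_count + self.incorrect_count
    # 0/0 in TensorFlow float division gives NaN; return NaN instead of raising.
    return self.correct_count / total if total else math.nan


# Usage: two batches, three of four predictions correct overall.
metric = StreamingAccuracy()
metric.update([[0.1, 0.9], [0.8, 0.2]], [1, 1])  # 1 correct, 1 incorrect
metric.update([[0.3, 0.7], [0.6, 0.4]], [1, 0])  # 2 correct
print(metric.result())  # 0.75
```

As in the TensorFlow version, updating and reading are separate steps. You call `update` (or run the update ops) for every evaluation batch and read the result once at the end. Because the counters are never reset automatically, a fresh evaluation pass needs new counters (or, in TensorFlow, re-initialized variables).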