metrics.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:tensorport-template 作者: tensorport 项目源码 文件源码
def single_label(predictions_batch, one_hot_labels_batch, moving_average=True, ema_decay=0.9):
    """Build single-label classification metric ops for one batch.

    The predicted class is the argmax of ``predictions_batch`` along axis 1.
    The true class is the argmax of ``one_hot_labels_batch`` along axis 1.
    Both are re-encoded as boolean one-hot masks and handed to ``_metrics``
    (the tp/tn/fp/fn computation). A ``num_outputs x num_outputs`` confusion
    matrix is then added to the resulting dict.

    Args:
        predictions_batch: Prediction tensor. It is assumed to be
            (batch, num_classes); axis 1 must be statically known.
        one_hot_labels_batch: One-hot label tensor with the same layout.
        moving_average: Controls how the confusion matrix is aggregated.
            If True, it is an exponential moving average (float32) of the
            per-batch confusion matrices, and the EMA update op is appended
            to ``d['update_op']``.
            If False, counts are accumulated into a non-trainable int32
            variable. The returned tensor is the ``assign_add`` op, so every
            evaluation of ``d['confusion_matrix']`` adds the current batch.
        ema_decay: Decay of the moving average. Only used when
            ``moving_average`` is True. Defaults to 0.9, the value the
            original code hard-coded.

    Returns:
        The dict produced by ``_metrics``, with ``'confusion_matrix'`` added.
        When ``moving_average`` is True, ``'update_op'`` is also replaced by
        a list holding the original update op and the EMA update op.

    Raises:
        ValueError: If the class dimension (axis 1 of ``predictions_batch``)
            is not statically known. It is needed as the one-hot depth and
            the confusion-matrix size.
    """
    with tf.variable_scope('metrics'):
        # Only the class count is needed; the batch dimension may be dynamic.
        num_outputs = predictions_batch.get_shape().as_list()[1]
        if num_outputs is None:
            raise ValueError('single_label requires a statically known number of '
                             'classes (axis 1 of predictions_batch).')

        # get the most probable label
        predicted_batch = tf.argmax(predictions_batch, axis=1)
        real_label_batch = tf.argmax(one_hot_labels_batch, axis=1)

        # tp, tn, fp, fn: compare boolean one-hot masks of predicted vs. real class
        predicted_bool = tf.cast(tf.one_hot(predicted_batch, depth=num_outputs), dtype=tf.bool)
        real_bool = tf.cast(tf.one_hot(real_label_batch, depth=num_outputs), dtype=tf.bool)
        d = _metrics(predicted_bool, real_bool, moving_average)

        # confusion matrix (rows: real label, columns: predicted label)
        confusion_batch = tf.confusion_matrix(labels=real_label_batch, predictions=predicted_batch,
                                              num_classes=num_outputs)

        if moving_average:
            # EMA needs a float tensor; apply() creates the shadow variable and update op.
            confusion_batch = tf.cast(confusion_batch, dtype=tf.float32)
            ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
            update_op = ema.apply([confusion_batch])
            confusion_matrix = ema.average(confusion_batch)
            d['update_op'] = [d['update_op'], update_op]
        else:
            # accumulative: running int32 total; evaluating the assign_add op adds this batch
            confusion_matrix = tf.Variable(tf.zeros([num_outputs, num_outputs], dtype=tf.int32),
                                           name='confusion_matrix', trainable=False)
            confusion_matrix = tf.assign_add(confusion_matrix, confusion_batch)

    d['confusion_matrix'] = confusion_matrix
    return d
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号