metrics.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:tensorport-template 作者: tensorport 项目源码 文件源码
def multi_label(prediction_batch, labels_batch, threshold=0.5, moving_average=True):
    with tf.variable_scope('metrics'):
        threshold_graph = tf.constant(threshold, name='threshold')
        zero_point_five = tf.constant(0.5)
        predicted_bool = tf.greater_equal(prediction_batch, threshold_graph)
        real_bool = tf.greater_equal(labels_batch, zero_point_five)
        return _metrics(predicted_bool, real_bool, moving_average)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号