def multi_label(prediction_batch, labels_batch, threshold=0.5, moving_average=True,
                label_threshold=0.5):
    """Binarize multi-label predictions and labels, then hand them to ``_metrics``.

    Args:
        prediction_batch: predicted scores, compared element-wise against
            ``threshold``. Presumably per-label probabilities of shape
            (batch, num_labels) -- TODO confirm against callers.
        labels_batch: ground-truth label values, compared element-wise against
            ``label_threshold``. Assumed to have the same shape as
            ``prediction_batch`` -- TODO confirm.
        threshold: a prediction is treated as positive when it is >= this value.
        moving_average: forwarded unchanged to ``_metrics``.
        label_threshold: a label is treated as positive when it is >= this value.
            Defaults to 0.5, the value that used to be hard-coded.

    Returns:
        Whatever ``_metrics(predicted_bool, real_bool, moving_average)`` returns.
    """
    with tf.variable_scope('metrics'):
        prediction_batch = tf.convert_to_tensor(prediction_batch)
        labels_batch = tf.convert_to_tensor(labels_batch)
        if not labels_batch.dtype.is_floating:
            # Integer/bool labels (e.g. 0/1) cannot be compared against a
            # fractional cut-off such as 0.5, so promote them to float first.
            labels_batch = tf.cast(labels_batch, tf.float32)
        # Build each cut-off in the dtype of the tensor it is compared with.
        # A bare tf.constant(0.5) is float32, and tf.greater_equal raises on a
        # float32 vs float64 mismatch.
        threshold_graph = tf.constant(
            threshold, dtype=prediction_batch.dtype, name='threshold')
        label_threshold_graph = tf.constant(label_threshold, dtype=labels_batch.dtype)
        predicted_bool = tf.greater_equal(prediction_batch, threshold_graph)
        real_bool = tf.greater_equal(labels_batch, label_threshold_graph)
        return _metrics(predicted_bool, real_bool, moving_average)
metrics.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录