bin.py source code

python

Project: cxflow-tensorflow  Author: Cognexa
from typing import Tuple

import tensorflow as tf


def bin_stats(predictions: tf.Tensor, labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """
    Calculate f1, precision and recall from binary classification expected and predicted values.

    :param predictions: 2-d tensor (batch, predictions) of predicted 0/1 classes
    :param labels: 2-d tensor (batch, labels) of expected 0/1 classes
    :return: a tuple of batched (f1, precision and recall) values
    """
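    # Work on integer 0/1 values so the comparisons below are exact.
    predictions = tf.cast(predictions, tf.int32)
    labels = tf.cast(labels, tf.int32)

    # Per-row counts: TP where both are 1, FP where prediction > label, FN where label > prediction.
    true_positives = tf.reduce_sum((predictions * labels), axis=1)
    false_positives = tf.reduce_sum(tf.cast(tf.greater(predictions, labels), tf.int32), axis=1)
    false_negatives = tf.reduce_sum(tf.cast(tf.greater(labels, predictions), tf.int32), axis=1)

    # `/` on integer tensors is true division and yields floats.
    # Rows with a zero denominator (e.g. no positive predictions) produce NaN.
    recall = true_positives / (true_positives + false_negatives)
    precision = true_positives / (true_positives + false_positives)
    # F1 is the harmonic mean of precision and recall.
    f1_score = 2 / (1 / precision + 1 / recall)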

    return f1_score, precision, recall
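
To check the arithmetic by hand, the same per-row computation can be sketched in plain Python without TensorFlow. This is only an illustrative sketch: `bin_stats_reference` is a hypothetical helper that is not part of cxflow-tensorflow, and returning 0.0 when a denominator is zero is an assumption of this sketch (the TensorFlow version above yields NaN in that case).

```python
from typing import List, Sequence, Tuple


def bin_stats_reference(predictions: Sequence[Sequence[int]],
                        labels: Sequence[Sequence[int]]) -> List[Tuple[float, float, float]]:
    """Per-row (f1, precision, recall) for 0/1 inputs; hypothetical helper for illustration."""
    results = []
    for pred_row, label_row in zip(predictions, labels):
        pairs = list(zip(pred_row, label_row))
        tp = sum(p * l for p, l in pairs)        # predicted 1, expected 1
        fp = sum(1 for p, l in pairs if p > l)   # predicted 1, expected 0
        fn = sum(1 for p, l in pairs if l > p)   # predicted 0, expected 1

        # Assumption of this sketch: report 0.0 instead of NaN on a zero denominator.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        # Harmonic mean, written in the equivalent form 2PR / (P + R).
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        results.append((f1, precision, recall))
    return results


# Row 0: one TP, one FP, one FN. Row 1: no positive predictions at all.
rows = bin_stats_reference([[1, 1, 0, 0], [0, 0, 0, 0]],
                           [[1, 0, 1, 0], [0, 1, 0, 0]])
print(rows)
```

In the first row precision and recall are both 1/2, so F1 is also 1/2. The second row is the case where the TensorFlow version divides by zero.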