def bin_stats(
    predictions: tf.Tensor, labels: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """
    Calculate f1, precision and recall from binary classification expected and predicted values.

    Each metric is computed per row by reducing over axis 1. When a metric's
    denominator is zero (e.g. a row with no positive labels, or with no
    positive predictions) the metric is reported as 0 instead of NaN, so one
    degenerate row cannot poison an average taken over the batch.

    :param predictions: 2-d tensor (batch, predictions) of predicted 0/1 classes
    :param labels: 2-d tensor (batch, labels) of expected 0/1 classes
    :return: a tuple of batched (f1, precision and recall) floating-point values
    """
    predictions = tf.cast(predictions, tf.int32)
    labels = tf.cast(labels, tf.int32)

    # For 0/1 inputs the product is 1 only where both prediction and label are 1.
    true_positives = tf.reduce_sum(predictions * labels, axis=1)
    # prediction > label  <=>  predicted 1 where 0 was expected.
    false_positives = tf.reduce_sum(
        tf.cast(tf.greater(predictions, labels), tf.int32), axis=1
    )
    # label > prediction  <=>  predicted 0 where 1 was expected.
    false_negatives = tf.reduce_sum(
        tf.cast(tf.greater(labels, predictions), tf.int32), axis=1
    )

    # All counts are non-negative, so a zero denominator implies a zero
    # numerator. Clamping the denominator to 1 therefore turns 0/0 into 0
    # and leaves every other ratio untouched.
    recall = true_positives / tf.maximum(true_positives + false_negatives, 1)
    precision = true_positives / tf.maximum(true_positives + false_positives, 1)

    # Harmonic mean of precision and recall written directly in terms of the
    # counts: 2 / (1/p + 1/r) == 2*TP / (2*TP + FP + FN). This avoids the
    # intermediate 1/0 when precision or recall is zero.
    f1_score = (2 * true_positives) / tf.maximum(
        2 * true_positives + false_positives + false_negatives, 1
    )
    return f1_score, precision, recall
# Leftover web-page navigation labels (not code): "Comment list", "Table of contents".