def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_loss'):
    """
    The class-balanced cross entropy loss for binary classification,
    as in `Holistically-Nested Edge Detection
    <http://arxiv.org/abs/1504.06375>`_.

    Positive terms are weighted by ``beta = #neg / (#neg + #pos)`` and
    negative terms by ``1 - beta``. The counts are taken over the whole
    batch, so the rarer class contributes proportionally more to the loss.

    :param pred: size: b x ANYTHING. the predictions in [0,1].
    :param label: size: b x ANYTHING. the ground truth in {0,1}.
    :param name: name given to the returned scalar loss tensor.
    :returns: class-balanced binary classification cross entropy loss.
        Each example is first averaged over all of its elements, and the
        result is then averaged over the batch.
    """
    # batch_flatten is defined elsewhere; presumably it reshapes to (b, N).
    # The reduce_mean(..., 1) calls below rely on a 2-D result.
    z = batch_flatten(pred)
    y = tf.cast(batch_flatten(label), tf.float32)

    # Class-balancing weight, computed from label counts over the entire batch.
    count_neg = tf.reduce_sum(1. - y)
    count_pos = tf.reduce_sum(y)
    beta = count_neg / (count_neg + count_pos)

    # eps keeps log() finite when a prediction is exactly 0 or 1.
    # abs() guards against values that drift slightly outside [0, 1] numerically.
    # NOTE(review): tf.log exists only in TF 1.x; on TF 2.x this would need
    # tf.math.log. Confirm the targeted TensorFlow version.
    eps = 1e-8
    loss_pos = -beta * tf.reduce_mean(y * tf.log(tf.abs(z) + eps), 1)
    loss_neg = (1. - beta) * tf.reduce_mean((1. - y) * tf.log(tf.abs(1. - z) + eps), 1)

    # tf.sub was removed in TensorFlow 1.0 (renamed tf.subtract).
    # The overloaded operator computes the same per-example difference on any version.
    cost = loss_pos - loss_neg
    cost = tf.reduce_mean(cost, name=name)
    return cost
# Scraped web-page residue (not code): "评论列表" (comment list), "文章目录" (table of contents).