def class_balanced_cross_entropy_loss_theoretical(output, label):
    """Class-balanced cross entropy loss, textbook formulation.

    The sigmoid is applied explicitly and the logarithm is taken of the
    resulting probabilities (plus a small epsilon). As the original note
    says, this version produces unstable results.

    Args:
        output: Output of the network, before the sigmoid.
        label: Ground truth label. Elements greater than 0 count as
            positives, elements less than 1 count as negatives.

    Returns:
        Tensor holding the loss: the positive log-likelihood term weighted
        by the fraction of negatives, plus the negative term weighted by the
        fraction of positives, with the sign flipped.
    """
    eps = 0.00001  # keeps the argument of the log away from zero

    probs = tf.nn.sigmoid(output)

    # Float masks selecting the positive / negative elements of the label.
    pos_mask = tf.cast(tf.greater(label, 0), tf.float32)
    neg_mask = tf.cast(tf.less(label, 1), tf.float32)

    pos_count = tf.reduce_sum(pos_mask)
    neg_count = tf.reduce_sum(neg_mask)
    total_count = pos_count + neg_count

    # Summed log-likelihood of each class under the predicted probabilities.
    pos_log_likelihood = tf.reduce_sum(pos_mask * tf.log(probs + eps))
    neg_log_likelihood = tf.reduce_sum(neg_mask * tf.log(1 - probs + eps))

    # Each class is weighted by the relative frequency of the *other* class,
    # so the rarer class contributes more per element.
    pos_weight = -neg_count / total_count
    neg_weight = pos_count / total_count

    return pos_weight * pos_log_likelihood - neg_weight * neg_log_likelihood
评论列表
文章目录