def calculate_loss(self, predictions, labels, **unused_params):
    """Cross-entropy loss with extra weight on batch-level ranking errors.

    Computes element-wise binary cross-entropy, then up-weights entries whose
    prediction lands on the "wrong side" of the opposite label's extreme,
    measured over the whole batch:

    * false negatives: label-1 entries predicted below ``max_np``, the largest
      prediction given to any label-0 entry;
    * false positives: label-0 entries predicted above ``min_pp``, the
      smallest prediction given to any label-1 entry.

    Each such entry gets weight ``1 + FLAGS.batch_agreement * w``. Here ``w``
    is a sigmoid of the entry's distance to the mean prediction of the
    opposing error group, scaled by that group's size divided by
    ``FLAGS.batch_size``. Every other entry keeps weight 1.0.

    Args:
      predictions: prediction tensor. The ``tf.reduce_sum(..., 1)`` below
        implies rank >= 2; presumably (batch, num_classes) with values in
        [0, 1] (e.g. sigmoid outputs) -- TODO confirm against the model.
      labels: presumably the same shape as ``predictions``; cast to float32
        and used as 0/1 targets.
      **unused_params: ignored.

    Returns:
      Scalar tensor: weighted cross-entropy summed over axis 1, then averaged
      over axis 0. Stays finite when the batch has no ranking errors (the
      error-group means are guarded against 0/0).

    NOTE(review): ``weight`` is built from ``predictions`` without
    ``tf.stop_gradient``, so gradients also flow through the weighting terms
    -- confirm this is intended.
    NOTE(review): scaling uses ``FLAGS.batch_size`` rather than the runtime
    batch size, so a smaller final batch is scaled as if it were full.
    """
    with tf.name_scope("loss_xent_batch"):
        batch_agreement = FLAGS.batch_agreement
        epsilon = 10e-6  # == 1e-5; keeps log() and the range divisor away from 0.
        float_batch_size = float(FLAGS.batch_size)
        float_labels = tf.cast(labels, tf.float32)

        # Element-wise binary cross-entropy.
        cross_entropy_loss = float_labels * tf.log(predictions + epsilon) + (
            1 - float_labels) * tf.log(1 - predictions + epsilon)
        cross_entropy_loss = tf.negative(cross_entropy_loss)

        # Lowest prediction among label-1 entries. Label-0 slots are set to
        # 1.0 so they cannot win the min (assuming predictions <= 1).
        positive_predictions = predictions * float_labels + 1.0 - float_labels
        min_pp = tf.reduce_min(positive_predictions)
        # Highest prediction among label-0 entries. Label-1 slots are zeroed
        # so they cannot win the max (assuming predictions >= 0).
        negative_predictions = predictions * (1.0 - float_labels)
        max_np = tf.reduce_max(negative_predictions)

        # 1s that fall under 0s.
        false_negatives = tf.cast(predictions < max_np, tf.float32) * float_labels
        num_fn = tf.reduce_sum(false_negatives)
        # Mean prediction of the false negatives. When num_fn == 0 the
        # numerator is also 0; clamping the divisor to 1 yields 0 instead of
        # NaN. The clamp is a no-op for any real count >= 1.
        center_fn = tf.reduce_sum(predictions * false_negatives) / tf.maximum(num_fn, 1.0)

        # 0s that grow over 1s.
        false_positives = tf.cast(predictions > min_pp, tf.float32) * (1.0 - float_labels)
        num_fp = tf.reduce_sum(false_positives)
        # Same empty-group guard as center_fn.
        center_fp = tf.reduce_sum(predictions * false_positives) / tf.maximum(num_fp, 1.0)

        # Width of the overlap between the two classes, kept strictly positive.
        false_range = tf.maximum(epsilon, max_np - min_pp)

        # For 1s that fall under 0s: a larger weight the further the entry
        # sits below the false-positive mean.
        weight_fn = tf.nn.sigmoid((center_fp - predictions) / false_range * 3.0) * (num_fp / float_batch_size) * false_negatives
        # For 0s that grow over 1s: a larger weight the further the entry
        # sits above the false-negative mean.
        weight_fp = tf.nn.sigmoid((predictions - center_fn) / false_range * 3.0) * (num_fn / float_batch_size) * false_positives

        # Baseline weight 1.0; error entries get the extra term.
        weight = (weight_fn + weight_fp) * batch_agreement + 1.0
        return tf.reduce_mean(tf.reduce_sum(weight * cross_entropy_loss, 1))
# Scraped web-page residue (not code): "评论列表" (comment list), "文章目录" (table of contents).