def calculate_loss(self, predictions, labels, topk=20, **unused_params):
    """Weighted sigmoid cross-entropy that up-weights batch-level top-k mistakes.

    Computes the standard elementwise binary cross-entropy, then multiplies it
    by a per-element weight that is raised (by FLAGS.batch_agreement) for:
      * false negatives — true labels whose prediction fell below the
        batch-wise top-k threshold, and
      * false positives — negatives that sit inside the top-k above the
        lowest-scoring top-k positive.
    The weight tensor is wrapped in tf.stop_gradient so it acts as a constant
    scaling factor and does not itself receive gradients.

    Args:
      predictions: float tensor of shape (batch, num_classes); assumed to be
        probabilities in [0, 1] — TODO confirm with caller (log is taken
        directly, so logits would be wrong here).
      labels: tensor of the same shape, castable to float32 {0, 1}.
      topk: number of top predictions per example used to build the mask.
      **unused_params: ignored; keeps a uniform loss-function signature.

    Returns:
      Scalar tensor: mean over the batch of the per-example summed
      weighted cross-entropy.
    """
    with tf.name_scope("loss_xent_batch"):
        batch_agreement = FLAGS.batch_agreement
        # NOTE(review): 10e-6 == 1e-5; the name suggests 1e-6 may have been
        # intended — kept as-is to preserve numerical behavior. Confirm.
        epsilon = 10e-6

        # BUG FIX: the original hard-coded k=20, silently ignoring the
        # `topk` parameter. Default of 20 preserves existing behavior.
        topk_predictions, _ = tf.nn.top_k(predictions, k=topk)
        min_topk_predictions = tf.reduce_min(
            topk_predictions, axis=1, keep_dims=True)
        # 1.0 where a prediction is inside the per-example top-k.
        topk_mask = tf.cast(
            predictions >= min_topk_predictions, dtype=tf.float32)

        float_labels = tf.cast(labels, tf.float32)
        cross_entropy_loss = float_labels * tf.log(predictions + epsilon) + (
            1 - float_labels) * tf.log(1 - predictions + epsilon)
        cross_entropy_loss = tf.negative(cross_entropy_loss)

        # Lowest positive prediction inside the top-k. Entries that are not
        # (positive AND in top-k) are shifted to 1.0 so reduce_min ignores
        # them.
        positive_predictions = (
            predictions * float_labels * topk_mask
        ) + 1.0 - (float_labels * topk_mask)
        min_pp = tf.reduce_min(positive_predictions)

        # False negatives: true labels that fell below the top-k threshold.
        false_negatives = tf.cast(
            predictions < min_topk_predictions, tf.float32) * float_labels
        # False positives: negatives inside the top-k scoring above the
        # lowest top-k positive.
        false_positives = tf.cast(
            predictions > min_pp, tf.float32) * (1.0 - float_labels) * topk_mask

        # Elementwise weight: 1.0 everywhere, plus batch_agreement on every
        # identified top-k mistake. stop_gradient keeps it a constant factor.
        weight = (false_negatives + false_positives) * batch_agreement + 1.0
        weight = tf.stop_gradient(weight)
        # Removed: debug `print weight`, and unused locals float_batch_size,
        # negative_predictions / max_np from the original.
        return tf.reduce_mean(tf.reduce_sum(weight * cross_entropy_loss, 1))
# 评论列表 ("comment list") / 文章目录 ("article table of contents"):
# blog-page navigation text left over from a web copy-paste — not part of
# the original source; preserved here as comments so the file stays valid.