losses.py source code

python

Project: youtube-8m  Author: wangheda
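def calculate_loss(self, predictions, labels, **unused_params):
    """Cross-entropy loss with extra weight on mis-ordered predictions.

    Positives scored below the highest-scoring negative in the batch, and
    negatives scored above the lowest-scoring positive, get a weight above
    1.0, scaled by FLAGS.batch_agreement. All other elements keep weight 1.0.
    """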
    with tf.name_scope("loss_xent_batch"):
      batch_agreement = FLAGS.batch_agreement
      epsilon = 10e-6
      float_batch_size = float(FLAGS.batch_size)

      float_labels = tf.cast(labels, tf.float32)
      cross_entropy_loss = float_labels * tf.log(predictions + epsilon) + (
          1 - float_labels) * tf.log(1 - predictions + epsilon)
      cross_entropy_loss = tf.negative(cross_entropy_loss)

      positive_predictions = predictions * float_labels + 1.0 - float_labels
      min_pp = tf.reduce_min(positive_predictions)

      negative_predictions = predictions * (1.0 - float_labels)
      max_np = tf.reduce_max(negative_predictions)

      # 1s that fall under 0s
      false_negatives = tf.cast(predictions < max_np, tf.float32) * float_labels
      num_fn = tf.reduce_sum(false_negatives)
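      # Guard the division: num_fn is 0 when the batch is perfectly separated,
      # and 0 / 0 would turn the whole loss into NaN.
      center_fn = tf.reduce_sum(predictions * false_negatives) / tf.maximum(num_fn, 1.0)

      # 0s that grow over 1s
      false_positives = tf.cast(predictions > min_pp, tf.float32) * (1.0 - float_labels)
      num_fp = tf.reduce_sum(false_positives)
      # Same guard as above for a batch with no false positives.
      center_fp = tf.reduce_sum(predictions * false_positives) / tf.maximum(num_fp, 1.0)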

      false_range = tf.maximum(epsilon, max_np - min_pp)

      # for 1s that fall under 0s
      weight_fn = tf.nn.sigmoid((center_fp - predictions) / false_range * 3.0) * (num_fp / float_batch_size) * false_negatives
      # for 0s that grow over 1s
      weight_fp = tf.nn.sigmoid((predictions - center_fn) / false_range * 3.0) * (num_fn / float_batch_size) * false_positives

      weight = (weight_fn + weight_fp) * batch_agreement + 1.0
      print(weight)  # debug output: prints the symbolic tensor, not its values
      return tf.reduce_mean(tf.reduce_sum(weight * cross_entropy_loss, 1))
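The weighting step is easier to follow on a tiny batch. Below is a minimal pure-Python sketch of how the `weight` tensor is built. It is not part of the project. The helper names `sigmoid` and `batch_agreement_weights` are hypothetical. It assumes binary 0/1 labels, uses the number of rows as the batch size instead of `FLAGS.batch_size`, and guards the two center divisions so that a perfectly separated batch yields weights of exactly 1.0.

```python
import math


def sigmoid(x):
    # Numerically stable form of 1 / (1 + exp(-x)).
    return 0.5 * (1.0 + math.tanh(0.5 * x))


def batch_agreement_weights(predictions, labels, batch_agreement=1.0,
                            epsilon=10e-6):
    """Hypothetical helper mirroring the weight computation in calculate_loss.

    predictions and labels are lists of rows, one row per example.
    Assumes labels are 0 or 1.
    """
    # Assumption: the batch size is the number of rows, not FLAGS.batch_size.
    batch_size = float(len(predictions))
    pairs = [(p, float(y))
             for row_p, row_y in zip(predictions, labels)
             for p, y in zip(row_p, row_y)]

    # Lowest score among positives (negatives are masked to 1.0).
    min_pp = min(p * y + 1.0 - y for p, y in pairs)
    # Highest score among negatives (positives are masked to 0.0).
    max_np = max(p * (1.0 - y) for p, y in pairs)

    def is_fn(p, y):
        # A positive that scores below the best-scoring negative.
        return y == 1.0 and p < max_np

    def is_fp(p, y):
        # A negative that scores above the worst-scoring positive.
        return y == 0.0 and p > min_pp

    fn_scores = [p for p, y in pairs if is_fn(p, y)]
    fp_scores = [p for p, y in pairs if is_fp(p, y)]
    num_fn, num_fp = len(fn_scores), len(fp_scores)

    # Assumption: guard the divisions so an empty group gives 0.0, not NaN.
    center_fn = sum(fn_scores) / max(num_fn, 1)
    center_fp = sum(fp_scores) / max(num_fp, 1)
    false_range = max(epsilon, max_np - min_pp)

    def weight(p, y):
        extra = 0.0
        if is_fn(p, y):
            extra = (sigmoid((center_fp - p) / false_range * 3.0)
                     * num_fp / batch_size)
        elif is_fp(p, y):
            extra = (sigmoid((p - center_fn) / false_range * 3.0)
                     * num_fn / batch_size)
        return extra * batch_agreement + 1.0

    return [[weight(p, float(y)) for p, y in zip(row_p, row_y)]
            for row_p, row_y in zip(predictions, labels)]


# Second row is mis-ordered: the positive (0.3) scores below the negative (0.6).
weights = batch_agreement_weights([[0.9, 0.1], [0.3, 0.6]],
                                  [[1, 0], [1, 0]])
```

In this example only the second row gets weights above 1.0, because its positive and negative are the only elements that cross the batch-wide boundaries `min_pp` and `max_np`. Note that both boundaries are taken over the whole batch and all classes, not per class.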