model_group.py source code


Project: answer-triggering  Author: jiez-osu
```python
import tensorflow as tf


def bag_hinge_loss(config, preds, sent_mask, flip_sent_mask, hete_mask,
                   sent_trgt, sent_num):
  """ HINGE LOSS:
      DEFINED AS: MAX(0, M - MIN(SENT+) + MAX(SENT-)), WITH MARGIN M = 0.20.
      THIS ONLY APPLIES TO HETE BAGS (hete_mask == 1); OTHER BAGS ADD ZERO LOSS.
  """
  margin = 0.20
  # 1 for negative (non-target) sentences, 0 for positive ones.
  flip_sent_trgt = \
      tf.ones([config.batch_size, sent_num], dtype=config.data_type) - sent_trgt
  # Shift negatives (and padding, via flip_sent_mask, assumed 1 - sent_mask)
  # up by at least 1 so the row minimum picks the lowest-scoring positive
  # sentence (assumes preds lie in [0, 1]).
  pos_preds = preds + flip_sent_trgt + flip_sent_mask  # [batch_size, sent_num]
  # Zero out positives and padding so the row maximum picks the
  # highest-scoring negative sentence (assumes preds >= 0).
  neg_preds = preds * flip_sent_trgt * sent_mask  # [batch_size, sent_num]
  min_pos_pred = tf.reduce_min(pos_preds, 1)  # [batch_size]
  # tf.print('min_pos_pred:', min_pos_pred)
  max_neg_pred = tf.reduce_max(neg_preds, 1)  # [batch_size]
  # tf.print('max_neg_pred:', max_neg_pred)

  # max(0, x) computed as a ReLU; non-hete bags are masked to zero.
  hinge_loss = hete_mask * tf.nn.relu(
      margin - min_pos_pred + max_neg_pred)  # [batch_size]
  # tf.print('hinge_loss:', hinge_loss, summarize=20)

  # Average over hete bags only; 1e-12 avoids division by zero when a batch
  # contains no hete bags.
  avg_hinge_loss = tf.reduce_sum(hinge_loss) / (tf.reduce_sum(hete_mask) + 1e-12)
  return avg_hinge_loss
```
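To check the masking arithmetic without TensorFlow, here is a minimal pure-Python sketch that mirrors the function on nested lists. The name `bag_hinge_loss_ref` is hypothetical. The sketch assumes that `flip_sent_mask` equals `1 - sent_mask`, that `preds` lie in [0, 1] (for example, sigmoid outputs), and that the margin is the same fixed 0.20. For each bag b it computes h_b * max(0, M - min over positive sentences + max over negative sentences). It then divides the sum by the number of hete bags.

```python
def bag_hinge_loss_ref(preds, sent_trgt, sent_mask, hete_mask, margin=0.20):
    """Hypothetical pure-Python reference for the masked bag hinge loss.

    preds, sent_trgt, sent_mask: one row per bag, one value per sentence slot.
    hete_mask: one 0/1 value per bag.
    Assumes 1 - sent_mask plays the role of flip_sent_mask and preds in [0, 1].
    """
    losses = []
    for row_p, row_t, row_m, h in zip(preds, sent_trgt, sent_mask, hete_mask):
        # Negatives and padded slots are shifted up by >= 1, so the minimum
        # lands on a positive sentence.
        pos = [p + (1 - t) + (1 - m) for p, t, m in zip(row_p, row_t, row_m)]
        # Positives and padded slots are zeroed, so the maximum lands on a
        # negative sentence.
        neg = [p * (1 - t) * m for p, t, m in zip(row_p, row_t, row_m)]
        losses.append(h * max(0.0, margin - min(pos) + max(neg)))
    # Average over hete bags; the epsilon mirrors the TF version.
    return sum(losses) / (sum(hete_mask) + 1e-12)


# Bag 1: positive 0.9 beats negative 0.3 by more than the margin -> loss 0.
# Bag 2: positive 0.4 scores below negative 0.6 -> loss 0.2 - 0.4 + 0.6 = 0.4.
loss = bag_hinge_loss_ref(
    preds=[[0.9, 0.3, 0.0], [0.4, 0.6, 0.5]],
    sent_trgt=[[1, 0, 0], [1, 0, 0]],
    sent_mask=[[1, 1, 0], [1, 1, 1]],
    hete_mask=[1, 1],
)
print(round(loss, 6))  # 0.2
```

Because the loss uses only the weakest positive and the strongest negative in each bag, a single badly ranked pair is enough to make a bag contribute.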