def bag_hinge_loss(config, preds, sent_mask, flip_sent_mask, hete_mask,
                   sent_trgt, sent_num, margin=0.20):
    """Bag-level hinge loss, averaged over heterogeneous ("hete") bags.

    For every bag (one row of ``preds``):

        loss = max(0, margin - min(pred over positive sentences)
                             + max(pred over negative sentences))

    i.e. the lowest-scored positive sentence should outscore the
    highest-scored negative sentence by at least ``margin``. Only bags
    weighted by ``hete_mask`` contribute, and the result is the sum of their
    losses divided by the sum of ``hete_mask``.

    Args:
        config: object providing ``data_type``, the dtype used for the
            label complement tensor.
        preds: sentence scores, [batch_size, sent_num]. The masking tricks
            below assume scores lie in [0, 1] (e.g. sigmoid outputs) --
            TODO confirm against caller.
        sent_mask: presumably 1 for real sentences and 0 for padding,
            same shape as ``preds``.
        flip_sent_mask: presumably ``1 - sent_mask`` (1 on padded slots).
        hete_mask: per-bag weights, [batch_size]; bags with weight 0 are
            excluded from the loss.
        sent_trgt: sentence labels, presumably 1 = positive, 0 = negative,
            same shape as ``preds``.
        sent_num: number of sentence slots per bag. No longer needed (shapes
            are taken from ``sent_trgt``); kept so existing callers work.
        margin: required score gap between the weakest positive and the
            strongest negative. Defaults to 0.20, the previously hard-coded
            value.

    Returns:
        Scalar tensor: mean hinge loss over the bags selected by
        ``hete_mask`` (a 1e-12 epsilon guards against dividing by zero when
        no bag is selected).
    """
    # Complement of the labels (1 on non-positive sentences). Shape follows
    # sent_trgt, so this works for any batch size, not only config.batch_size.
    flip_sent_trgt = (tf.ones_like(sent_trgt, dtype=config.data_type)
                      - sent_trgt)

    # Additive mask for the min: non-positive sentences get +1 from
    # flip_sent_trgt and padded slots get +flip_sent_mask, so for scores in
    # [0, 1] the row minimum lands on a real positive sentence.
    pos_preds = preds + flip_sent_trgt + flip_sent_mask  # [batch_size, sent_num]
    # Multiplicative mask for the max: positives and padding are zeroed, so
    # for non-negative scores the row maximum lands on a real negative.
    neg_preds = preds * flip_sent_trgt * sent_mask  # [batch_size, sent_num]

    min_pos_pred = tf.reduce_min(pos_preds, 1)  # [batch_size]
    max_neg_pred = tf.reduce_max(neg_preds, 1)  # [batch_size]

    # max(0, x) per bag; replaces stacking a zero vector with the removed
    # tf.pack and reducing, which also hard-coded the batch size.
    hinge_loss = hete_mask * tf.maximum(
        margin - min_pos_pred + max_neg_pred, 0.0)  # [batch_size]

    return tf.reduce_sum(hinge_loss) / (tf.reduce_sum(hete_mask) + 1e-12)
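# Leftover navigation labels from the web page this code was copied from
# (评论列表 = "Comments", 文章目录 = "Table of contents"); not part of the code.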