def _cls_mining(self, scores, status, hard_neg_ratio=3.0, scope=None):
    """
    Positive classification loss and hard negative classification loss
    (online hard negative mining: keep at most hard_neg_ratio negatives
    per positive, choosing the negatives the classifier is most confident
    are not background).
    ARGS
      scores: [n, n_classes] classification logits
      status: int [n] node or link matching status
    RETURNS
      pos_loss: [] summed cross-entropy loss over positive examples
      n_pos: int [] number of positive examples
      hard_neg_loss: [] summed cross-entropy loss over hard negative examples
      n_hard_neg: int [] number of hard negative examples
    """
    with tf.variable_scope(scope or 'cls_mining'):
        # positive classification loss
        pos_mask = tf.equal(status, MATCH_STATUS_POS)
        pos_scores = tf.boolean_mask(scores, pos_mask)
        n_pos = tf.shape(pos_scores)[0]
        pos_labels = tf.fill([n_pos], POS_LABEL)
        pos_loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=pos_scores, labels=pos_labels))

        # hard negative classification loss
        neg_mask = tf.equal(status, MATCH_STATUS_NEG)
        neg_scores = tf.boolean_mask(scores, neg_mask)
        n_neg = tf.shape(neg_scores)[0]
        # number of hard negatives = hard_neg_ratio * n_pos, capped by the
        # number of available negatives
        n_hard_neg = tf.cast(n_pos, tf.float32) * hard_neg_ratio
        n_hard_neg = tf.minimum(n_hard_neg, tf.cast(n_neg, tf.float32))
        n_hard_neg = tf.cast(n_hard_neg, tf.int32)
        # probability of the background (negative) class for each negative example
        neg_prob = tf.nn.softmax(neg_scores)[:, NEG_LABEL]
        # find the k examples with the smallest negative-class probabilities,
        # i.e. the negatives the classifier misclassifies most confidently
        _, hard_neg_indices = tf.nn.top_k(-neg_prob, k=n_hard_neg)
        hard_neg_scores = tf.gather(neg_scores, hard_neg_indices)
        hard_neg_labels = tf.fill([n_hard_neg], NEG_LABEL)
        hard_neg_loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=hard_neg_scores, labels=hard_neg_labels))
    return pos_loss, n_pos, hard_neg_loss, n_hard_neg
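
# --- Usage sketch (an assumption, not part of the original code) ---
# A minimal example of how the mined losses might be folded into a single
# classification loss term. The method and argument names below
# (_build_node_cls_loss, node_cls_scores, node_match_status) are
# hypothetical placeholders.
def _build_node_cls_loss(self, node_cls_scores, node_match_status):
    pos_loss, n_pos, hard_neg_loss, n_hard_neg = self._cls_mining(
        node_cls_scores, node_match_status,
        hard_neg_ratio=3.0, scope='node_cls_mining')
    # normalize by the number of contributing examples so the loss scale
    # does not depend on how many nodes were matched
    n_total = tf.cast(n_pos + n_hard_neg, tf.float32)
    node_cls_loss = (pos_loss + hard_neg_loss) / tf.maximum(n_total, 1.0)
    return node_cls_loss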