def OHNM_single_image(scores, n_pos, neg_mask, neg_ratio=3):
    """Online Hard Negative Mining (OHNM) for a single image.

    Keeps the negatives the model finds hardest, i.e. those with the lowest
    score of being predicted as the negative class, capping their number at
    ``neg_ratio`` times the number of positives.

    Args:
        scores: per-sample score of being predicted as the negative class.
            Presumably the same shape as ``neg_mask`` (required by
            ``tf.boolean_mask`` / ``tf.logical_and`` below) -- TODO confirm
            with callers.
        n_pos: number of positive samples (Python number or scalar tensor).
        neg_mask: mask of negative-sample candidates (bool, or 0/1 values
            castable to bool).
        neg_ratio: maximum number of selected negatives per positive sample.
            Defaults to 3, the ratio this function always used.

    Returns:
        A float32 tensor shaped like ``neg_mask``: 1.0 where a negative sample
        is selected, 0.0 elsewhere. No negatives are selected when
        ``n_pos == 0`` or when ``neg_mask`` contains no negatives. Every
        negative whose score ties the cutoff is kept, so the selected count
        can exceed the cap when scores tie.
    """
    neg_mask = tf.cast(neg_mask, tf.bool)
    n_pos = tf.convert_to_tensor(n_pos)

    # Number of negatives to keep: neg_ratio * n_pos, capped by how many
    # negative candidates actually exist. Cast to int32 *before* tf.minimum
    # so the operands agree in dtype; top_k also needs an int32 k.
    max_neg_entries = tf.reduce_sum(tf.cast(neg_mask, tf.int32))
    n_neg = tf.cast(tf.cast(n_pos, tf.float32) * neg_ratio, tf.int32)
    n_neg = tf.minimum(n_neg, max_neg_entries)

    def select_hardest():
        neg_scores = tf.boolean_mask(scores, neg_mask)
        # top_k on the negated scores yields the n_neg *lowest* scores
        # (negated); the last entry marks the cutoff for "hard" negatives.
        lowest, _ = tf.nn.top_k(-neg_scores, k=n_neg)
        cutoff = -lowest[-1]
        selected = tf.logical_and(neg_mask, scores <= cutoff)
        return tf.cast(selected, tf.float32)

    def select_none():
        return tf.zeros_like(neg_mask, tf.float32)

    # Guard on n_neg rather than n_pos: with k == 0, top_k returns an empty
    # tensor and indexing [-1] fails (e.g. n_pos > 0 but no negatives).
    # n_neg > 0 also implies n_pos > 0, so the n_pos == 0 contract holds.
    return tf.cond(n_neg > 0, select_hardest, select_none)
评论列表
文章目录