model.py source code


Project: tensorflow_ocr  Author: BowieHsu
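```python
import tensorflow as tf


def OHNM_single_image(scores, n_pos, neg_mask):
    """Online Hard Negative Mining.

    scores: the scores of being predicted as the negative class
    n_pos: the number of positive samples
    neg_mask: boolean mask of negative samples
    Return:
    the float32 mask of selected negative samples.
    If n_pos == 0, no negative samples will be selected.
    """
    def has_pos():
        # Keep 3 negatives per positive, but no more than the negatives available.
        n_neg = tf.cast(n_pos, tf.int32) * 3
        max_neg_entries = tf.reduce_sum(tf.cast(neg_mask, tf.int32))
        n_neg = tf.minimum(n_neg, max_neg_entries)
        neg_conf = tf.boolean_mask(scores, neg_mask)
        # Hard negatives have the lowest negative-class scores, so take top-k of -scores.
        vals, _ = tf.nn.top_k(-neg_conf, k=n_neg)
        threshold = vals[-1]  # a negative value: minus the k-th smallest negative score
        selected_neg_mask = tf.logical_and(neg_mask, scores <= -threshold)
        return tf.cast(selected_neg_mask, tf.float32)

    def no_pos():
        return tf.zeros_like(neg_mask, tf.float32)

    # Also require at least one negative: top_k with k=0 returns an empty
    # tensor, and vals[-1] would then fail.
    has_neg = tf.reduce_any(neg_mask)
    return tf.cond(tf.logical_and(n_pos > 0, has_neg), has_pos, no_pos)
```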
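To make the selection rule concrete without a TensorFlow dependency, here is a pure-Python sketch of what the function computes. `ohnm_single_image_reference` and its `neg_ratio` parameter are hypothetical names introduced only for illustration (the original hard-codes a 3:1 negative-to-positive ratio). Taking `top_k` of the negated scores and reading `vals[-1]` is equivalent to picking the k-th smallest negative-class score, which is what the sketch does with `sorted`.

```python
def ohnm_single_image_reference(scores, n_pos, neg_mask, neg_ratio=3):
    """Pure-Python sketch of the OHNM selection rule (hypothetical helper).

    scores:    negative-class score per sample (lower = harder negative)
    n_pos:     number of positive samples
    neg_mask:  list of bools marking the negative samples
    neg_ratio: negatives kept per positive (assumed; the original uses 3)
    Returns a list of 0.0/1.0 flags marking the selected negatives.
    """
    neg_scores = [s for s, is_neg in zip(scores, neg_mask) if is_neg]
    # Keep neg_ratio negatives per positive, capped by the negatives available.
    n_neg = min(n_pos * neg_ratio, len(neg_scores))
    if n_pos <= 0 or n_neg == 0:
        return [0.0] * len(scores)
    # The k-th smallest negative score is the cut-off; ties at it are all kept,
    # matching the `scores <= -threshold` comparison in the TensorFlow code.
    threshold = sorted(neg_scores)[n_neg - 1]
    return [1.0 if is_neg and s <= threshold else 0.0
            for s, is_neg in zip(scores, neg_mask)]


scores = [0.9, 0.2, 0.7, 0.1, 0.95, 0.4]
neg_mask = [True, True, False, True, True, True]
print(ohnm_single_image_reference(scores, n_pos=1, neg_mask=neg_mask))
# prints [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

With one positive, three negatives are kept: those with the lowest negative-class scores (0.1, 0.2 and 0.4), i.e. the background samples the model is least confident about. Because the cut-off is inclusive, ties at the threshold can select slightly more than `3 * n_pos` negatives.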