detregionloss.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:VGG 作者: jackfan00 项目源码 文件源码
def yoloclassloss(y_true, y_pred, t):
    """Return the masked squared-error class loss and a class-score diagnostic.

    Args:
        y_true: Ground-truth class tensor. It is reduced over axis 1, so it
            presumably has the class scores on axis 1 -- TODO confirm
            against the caller.
        y_pred: Predicted class tensor; it is subtracted from ``y_true`` and
            selected with the same mask, so it is expected to have the same
            shape as ``y_true``.
        t: Boolean tensor used as the element-wise selection condition.
            Presumably marks the entries that belong to an object -- verify
            against the caller.

    Returns:
        A ``(loss, ave_cat)`` tuple. ``loss`` is the mean over all elements
        of ``lamda_class * (y_true - y_pred) ** 2`` where ``t`` is true and
        0 elsewhere. ``ave_cat`` is the sum of the masked predictions of the
        rows whose ``y_true`` sum exceeds 0.5, divided by the total
        ``y_true`` sum; it is -1 when that total is not greater than 0.5.
    """
    # NOTE: tf.select is the pre-1.0 TensorFlow name of this op (tf.where in
    # later releases); it is kept so the block matches the rest of the file.
    squared_err = K.square(y_true - y_pred)
    value_if_true = lamda_class * squared_err
    value_if_false = K.zeros_like(y_true)
    # Entries where t is false contribute zero to the loss.
    loss1 = tf.select(t, value_if_true, value_if_false)

    # Only extract the predicted class value at the masked locations,
    # then collapse axis 1.
    cat = K.sum(tf.select(t, y_pred, K.zeros_like(y_pred)), axis=1)

    # Per-row sum of the ground truth; a sum above 0.5 means the row holds
    # at least one valid object (possibly several).
    objsum = K.sum(y_true, axis=1)
    isobj = K.greater(objsum, 0.5)

    # Keep the predicted class value only for rows that contain an object.
    valid_cat = tf.select(isobj, cat, K.zeros_like(cat))

    # Guard against dividing by zero: when the total ground-truth sum is not
    # above 0.5, report the sentinel -1 instead of the ratio.
    total_obj = K.sum(objsum)
    ave_cat = tf.select(
        K.greater(total_obj, 0.5),
        K.sum(valid_cat) / total_obj,
        -1,
    )
    return K.mean(loss1), ave_cat
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号