def yoloclassloss(y_true, y_pred, t):
    """Masked, weighted squared-error class loss plus a mean-prediction metric.

    Relies on names defined elsewhere in the module: ``K`` (presumably the
    Keras backend), ``tf`` and the class-loss weight ``lamda_class``.

    Args:
        y_true: ground-truth class tensor. Assumed to be a float tensor whose
            axis 1 is reduced below (TODO confirm layout against the caller).
        y_pred: predicted class tensor; must have the same shape as ``y_true``.
        t: boolean mask with the same shape as ``y_true``; only positions
            where it is True contribute (presumably object locations - verify).

    Returns:
        A pair ``(loss, ave_cat)``:

        * ``loss`` - mean over ALL elements of ``lamda_class * (y_true -
          y_pred) ** 2`` where ``t`` is True and 0 elsewhere.
        * ``ave_cat`` - the masked ``y_pred`` values, summed along axis 1 and
          kept only for rows whose ``y_true`` sum exceeds 0.5, summed, then
          divided by the total sum of ``y_true``. Returns ``-1.0`` when that
          total is not greater than 0.5, i.e. when there is nothing to average.
    """
    # tf.select was removed in TensorFlow 1.0 in favour of the 3-argument
    # tf.where; use whichever one this TensorFlow build provides.
    select = getattr(tf, "select", tf.where)

    # Weighted squared error, kept only where the mask is set.
    sq_err = K.square(y_true - y_pred)
    masked_loss = select(t, lamda_class * sq_err, K.zeros_like(y_true))

    # Predicted class values under the mask only, summed along axis 1.
    cat = K.sum(select(t, y_pred, K.zeros_like(y_pred)), axis=1)
    # A row whose ground-truth sum exceeds 0.5 contains at least one object.
    objsum = K.sum(y_true, axis=1)
    isobj = K.greater(objsum, 0.5)
    valid_cat = select(isobj, cat, K.zeros_like(cat))

    # select evaluates both branches, so an unguarded division yields 0/0 = NaN
    # (and NaN gradients) when there are no objects. Clamping the denominator
    # removes that NaN; it cannot change the returned value, because the ratio
    # is only selected when total > 0.5, and then max(total, eps) == total.
    total = K.sum(objsum)
    ratio = K.sum(valid_cat) / K.maximum(total, K.epsilon())
    ave_cat = select(K.greater(total, 0.5), ratio, -K.ones_like(ratio))

    return K.mean(masked_loss), ave_cat
# Scraped web-page residue: "评论列表" (comment list), "文章目录" (table of contents).