def loss(y_true_pixel, y_pred_pixel,
         y_true_link, y_pred_link,
         training_mask):
    """Return the combined loss ``link_loss + 2 * pixel_classification_loss``.

    Pixel loss (OHEM mode): a softmax cross-entropy is computed per pixel,
    weighted by the selection mask returned by ``OHNM_batch`` (positives plus
    mined hard negatives), summed, and divided by the number of positive
    pixels.  It is 0 when there are no positive pixels.

    Link loss: for each of the 8 link channels, a positive-link and a
    negative-link cross-entropy are computed over the same selected pixels.
    Each is normalised by its own weight sum (0 when that sum is 0), and all
    16 terms are added together.

    Args:
        y_true_pixel: ground-truth pixel labels; reshaped to (batch, -1) and
            cast to int32.
        y_pred_pixel: pixel logits; reshaped to (batch, -1, 2), so each pixel
            must carry 2 class scores.
        y_true_link: ground-truth link labels, indexed as ``[:, :, :, i]`` for
            i in 0..7 (presumably NHWC with the same spatial size as the pixel
            maps -- TODO confirm).
        y_pred_link: link logits, indexed as ``[:, :, :, 2*i:2*i+2]``, i.e. a
            2-way score per link channel.
        training_mask: not used by this function.

    Returns:
        Scalar float tensor. Also records the ``classification_loss`` and
        ``link_loss`` scalar summaries as a side effect.
    """

    def _safe_divide(numerator, denominator):
        """Return numerator / denominator, or 0.0 when denominator <= 0."""
        # tf.cond (rather than a plain division) avoids 0/0 = NaN, which
        # would otherwise poison both the loss value and its gradients.
        return tf.cond(denominator > 0,
                       lambda: numerator / denominator,
                       lambda: tf.constant(.0))

    batch_size = tf.shape(y_pred_pixel)[0]
    pixel_label = tf.cast(tf.reshape(y_true_pixel, [batch_size, -1]), dtype=tf.int32)
    pixel_pred = tf.reshape(y_pred_pixel, [batch_size, -1, 2])
    pixel_scores = slim.softmax(pixel_pred)
    # Column 0 of the softmax is what OHNM_batch receives as the "negative" score.
    pixel_neg_scores = pixel_scores[:, :, 0]
    pixel_pos_mask, pixel_neg_mask = get_pos_and_neg_masks(pixel_label)
    # NOTE(review): the meaning of 14 is defined by OHNM_batch (presumably a
    # hard-negative ratio or count) -- confirm before tuning.
    pixel_selected_mask = OHNM_batch(14, pixel_neg_scores, pixel_pos_mask, pixel_neg_mask)
    n_seg_pos = tf.reduce_sum(tf.cast(pixel_pos_mask, tf.float32))

    with tf.name_scope('ohem_pixel_loss'):
        pixel_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=pixel_pred, labels=pixel_label)
        classification_loss = _safe_divide(
            tf.reduce_sum(pixel_ce * pixel_selected_mask), n_seg_pos)

    link_losses = []
    with tf.name_scope('link_loss'):
        # Flattened pixel selection weights; the same for every link channel.
        w_pixel = tf.reshape(pixel_selected_mask, [-1])
        for i in range(8):
            link_label = tf.cast(tf.reshape(y_true_link[:, :, :, i], [-1]), dtype=tf.int32)
            link_pred = tf.reshape(y_pred_link[:, :, :, 2 * i:2 * (i + 1)], [-1, 2])
            link_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=link_pred, labels=link_label)
            link_pos_mask, link_neg_mask = get_pos_and_neg_masks(link_label)
            w_link_pos = tf.cast(link_pos_mask, dtype=tf.float32) * w_pixel
            w_link_neg = tf.cast(link_neg_mask, dtype=tf.float32) * w_pixel
            # Either count can be 0 (e.g. no selected pixel has a positive link
            # in this direction); guard it the same way as the pixel loss.
            link_pos_loss = _safe_divide(tf.reduce_sum(link_ce * w_link_pos),
                                         tf.reduce_sum(w_link_pos))
            link_neg_loss = _safe_divide(tf.reduce_sum(link_ce * w_link_neg),
                                         tf.reduce_sum(w_link_neg))
            link_losses.append(link_pos_loss + link_neg_loss)
        weight_link_loss = tf.add_n(link_losses)

    tf.summary.scalar('classification_loss', classification_loss)
    tf.summary.scalar('link_loss', weight_link_loss)
    return weight_link_loss + 2 * classification_loss
# Web-page residue from the original source ("Comment list", "Table of contents"); not part of the code.