def _nearest_resize(images, height, width):
    """Resize a 4-D image batch to (height, width) with nearest-neighbour sampling.

    NOTE(review): height and width are passed as separate positional
    arguments, matching the older TensorFlow
    ``tf.image.resize_images(images, new_height, new_width, method)``
    signature. TF >= 1.0 expects ``resize_images(images, size, method=...)``
    instead -- confirm which TensorFlow version this file targets.
    """
    return tf.image.resize_images(images, height, width,
                                  tf.image.ResizeMethod.NEAREST_NEIGHBOR)


def det_net_loss(seg_masks_in, reg_masks_in,
                 seg_preds, reg_preds,
                 reg_loss_weight=10.0,
                 epsilon=1e-5):
    """Return a segmentation loss plus a weighted, masked regression loss.

    Both ground-truth inputs are first downsampled (nearest neighbour) to
    the spatial size given by dimensions 1 and 2 of ``seg_preds``. All ops
    are created under the ``'loss'`` variable scope.

    Args:
        seg_masks_in: ground-truth segmentation labels. A trailing axis is
            added before resizing, and the labels are one-hot encoded with
            depth 2, so presumably (batch, height, width) with values in
            {0, 1} -- TODO confirm against callers.
        reg_masks_in: ground-truth regression targets, resized as a 4-D
            batch (presumably (batch, height, width, channels) -- confirm).
        seg_preds: predicted class probabilities. Its static shape supplies
            the output spatial size, and it is combined elementwise with the
            2-class one-hot labels.
        reg_preds: predicted regression values, compared against the
            resized ``reg_masks_in``.
        reg_loss_weight: multiplier applied to the regression term.
        epsilon: added to ``seg_preds`` before the log to avoid log(0).

    Returns:
        Scalar tensor ``seg_loss + reg_loss_weight * reg_loss``.
    """
    with tf.variable_scope('loss'):
        # Target (H, W), read from the *static* shape of the predictions.
        target_hw = seg_preds.get_shape()[1:3]
        out_h = target_hw[0]
        out_w = target_hw[1]

        # Resize the seg labels first, then the regression targets.
        labels_small = _nearest_resize(seg_masks_in[:, :, :, tf.newaxis],
                                       out_h, out_w)
        targets_small = _nearest_resize(reg_masks_in, out_h, out_w)

        # Segmentation term: cross-entropy against 2-class one-hot labels.
        # reduce_mean averages over every element, class axis included, so
        # the value is half of a per-pixel summed cross-entropy.
        labels_onehot = slim.one_hot_encoding(labels_small[:, :, :, 0], 2)
        log_probs = tf.log(seg_preds + epsilon)
        seg_term = -tf.reduce_mean(labels_onehot * log_probs)

        # Regression term: squared error weighted by the float-cast seg
        # labels (a 0/1 mask when labels are binary). The single trailing
        # label axis broadcasts across the regression channels.
        weights = tf.to_float(labels_small)
        sq_err = (reg_preds - targets_small) ** 2
        weighted_err_sum = tf.reduce_sum(weights * sq_err)
        # The +1.0 keeps the denominator non-zero for an empty mask. It sums
        # the one-channel mask, so it counts masked pixels, not elements.
        reg_term = weighted_err_sum / (tf.reduce_sum(weights) + 1.0)

        return seg_term + reg_loss_weight * reg_term
# (Stray web-page navigation text, kept as a comment: "Comment list", "Article table of contents".)