detnet.py source file

python

Project: social-scene-understanding  Author: cvlab-epfl
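import tensorflow as tf
import tensorflow.contrib.slim as slim  # assumed import: the snippet uses `slim` without showing where it comes from


def det_net_loss(seg_masks_in, reg_masks_in,
                 seg_preds, reg_preds,
                 reg_loss_weight=10.0,
                 epsilon=1e-5):
  """Segmentation cross-entropy plus weighted, foreground-masked regression loss.

  Shapes as implied by the indexing below: seg_masks_in is [batch, H, W],
  reg_masks_in is [batch, H, W, C], seg_preds is [batch, h, w, 2] class
  probabilities, and reg_preds is [batch, h, w, C].
  """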
  with tf.variable_scope('loss'):
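    # Downsample the ground-truth masks to the spatial size of the predictions.
    # Nearest-neighbour resizing keeps the labels discrete.
    # Note: resize_images is called here with separate height and width
    # arguments; later TensorFlow 1.x releases take a single `size` argument.
    out_size = seg_preds.get_shape()[1:3]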
    seg_masks_in_ds = tf.image.resize_images(seg_masks_in[:,:,:,tf.newaxis],
                                             out_size[0], out_size[1],
                                             tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    reg_masks_in_ds = tf.image.resize_images(reg_masks_in,
                                             out_size[0], out_size[1],
                                             tf.image.ResizeMethod.NEAREST_NEIGHBOR)

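    # segmentation loss: two-class cross-entropy averaged over all pixels
    # and classes; epsilon keeps tf.log away from log(0)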
    seg_masks_onehot = slim.one_hot_encoding(seg_masks_in_ds[:,:,:,0], 2)
    seg_loss = - tf.reduce_mean(seg_masks_onehot * tf.log(seg_preds + epsilon))

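    # regression loss: squared error summed over foreground pixels only
    # (mask is 1 where the segmentation label is 1, and broadcasts over channels)
    mask = tf.to_float(seg_masks_in_ds)
    reg_loss = tf.reduce_sum(mask * (reg_preds - reg_masks_in_ds)**2)
    # normalize by the foreground pixel count; the +1.0 avoids division by zero
    reg_loss = reg_loss / (tf.reduce_sum(mask) + 1.0)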

  return seg_loss + reg_loss_weight * reg_loss
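
To make the arithmetic of the loss easier to check by hand, here is a minimal NumPy sketch of the same two terms. It is not part of the original project: the name `det_net_loss_np` is hypothetical, and it assumes the ground-truth masks have already been resized to the prediction resolution, so the `resize_images` step is skipped.

```python
import numpy as np


def det_net_loss_np(seg_mask, reg_target, seg_probs, reg_preds,
                    reg_loss_weight=10.0, epsilon=1e-5):
    """NumPy re-statement of the loss above (illustrative sketch only).

    Assumed shapes: seg_mask [B, H, W] with integer labels in {0, 1};
    seg_probs [B, H, W, 2]; reg_target and reg_preds [B, H, W, C].
    """
    # One-hot encode the labels: [B, H, W] -> [B, H, W, 2].
    onehot = np.eye(2)[seg_mask]

    # Cross-entropy averaged over every pixel and both classes.
    seg_loss = -np.mean(onehot * np.log(seg_probs + epsilon))

    # Foreground mask with a trailing axis so it broadcasts over channels.
    mask = seg_mask[..., np.newaxis].astype(np.float64)

    # Squared error on foreground pixels, normalized by (foreground count + 1).
    reg_loss = np.sum(mask * (reg_preds - reg_target) ** 2) / (np.sum(mask) + 1.0)

    return seg_loss + reg_loss_weight * reg_loss
```

Because the regression error is multiplied by the mask, errors on background pixels contribute nothing to the total; only the segmentation term sees those pixels.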