utils.py source code

python

Project: main_loop_tf    Author: fvisin
import gflags
import tensorflow as tf


def apply_loss(labels, net_out, loss_fn, weight_decay, is_training,
               return_mean_loss=False, mask_voids=True):
    '''Applies the user-specified loss function and returns the loss

    Note:
        SoftmaxCrossEntropyWithLogits expects labels NOT to be one-hot
        and net_out to be one-hot.
    '''
    # Configuration object populated by the framework at startup
    cfg = gflags.cfg

    if mask_voids and len(cfg.void_labels):
        # TODO Check this
        print('Masking the void labels')
        mask = tf.not_equal(labels, cfg.void_labels)
        labels *= tf.cast(mask, 'int32')  # void_class --> 0 (random class)
        # Train loss
        loss = loss_fn(labels=labels,
                       logits=tf.reshape(net_out, [-1, cfg.nclasses]))
        # Zero out the loss contribution of the void pixels
        mask = tf.cast(mask, 'float32')
        loss *= mask
    else:
        # Train loss
        loss = loss_fn(labels=labels,
                       logits=tf.reshape(net_out, [-1, cfg.nclasses]))

    if is_training:
        loss = apply_l2_penalty(loss, weight_decay)

    # Return the mean loss (over pixels *and* batches)
    if return_mean_loss:
        if mask_voids and len(cfg.void_labels):
            # Average only over the non-void pixels
            return tf.reduce_sum(loss) / tf.reduce_sum(mask)
        else:
            return tf.reduce_mean(loss)
    else:
        return loss
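
The helper apply_l2_penalty is defined elsewhere in utils.py and is not part of this snippet. As a rough illustration only, a helper of this kind typically adds a weight-decay-scaled L2 norm of the trainable weights to the loss; the sketch below is an assumption, not the project's actual implementation.

def apply_l2_penalty(loss, weight_decay):
    # Hypothetical sketch: sum the L2 norms of the trainable weights
    # (skipping biases) and add them, scaled by weight_decay, to the loss.
    # The real apply_l2_penalty in main_loop_tf may differ.
    l2_losses = [tf.nn.l2_loss(v) for v in tf.trainable_variables()
                 if 'bias' not in v.name]
    return loss + weight_decay * tf.add_n(l2_losses)

A minimal usage sketch, assuming labels_flat is a 1-D int32 tensor of class indices and net_out holds the raw network logits (both names are illustrative, not taken from the project):

# Hypothetical call site; labels_flat and net_out come from the user's model.
loss = apply_loss(labels=labels_flat,
                  net_out=net_out,
                  loss_fn=tf.nn.sparse_softmax_cross_entropy_with_logits,
                  weight_decay=1e-4,
                  is_training=True,
                  return_mean_loss=True)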