lossFunction.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:dwt 作者: min2209 项目源码 文件源码
def depthCELoss2(pred, gt, weight, ss, outputChannels=16):
    """Return the summed, weighted per-class binary cross-entropy loss.

    ``pred`` is flattened to rows of ``outputChannels`` logits and passed
    through a softmax. ``gt`` is flattened and one-hot encoded to the same
    width. For every row and class the binary cross-entropy term
    ``gt*log(p) + (1-gt)*log(1-p)`` is computed, multiplied by the per-row
    ``ss`` and ``weight`` factors and by a per-class scaling vector, summed
    over classes, and finally summed over all rows.

    Args:
        pred: Logits tensor; reshaped to ``(-1, outputChannels)``.
        gt: Class-index tensor with one entry per row of ``pred`` after
            flattening; cast to int32 before one-hot encoding.
        weight: Per-row weight tensor; reshaped to ``(-1, 1)``.
        ss: Per-row multiplicative factor; reshaped to ``(-1, 1)``.
            NOTE(review): presumably a semantic-segmentation mask -- confirm
            against the caller.
        outputChannels: Number of classes as a Python int (default 16).

    Returns:
        A scalar float tensor named ``cross_entropy_sum``.
    """
    with tf.name_scope("depth_CE_loss"):
        logits = tf.reshape(pred, (-1, outputChannels))
        # Floor for the log arguments so log(0) never produces -inf.
        epsilon = tf.constant(value=1e-25)
        predSoftmax = tf.to_float(tf.nn.softmax(logits))

        # Flatten straight to rank 1 (instead of reshape + squeeze) so a
        # single-element batch keeps its row dimension after one_hot.
        gtIndices = tf.to_int32(tf.reshape(gt, (-1,)))
        gtOneHot = tf.one_hot(indices=gtIndices, depth=outputChannels, dtype=tf.float32)
        ss = tf.to_float(tf.reshape(ss, (-1, 1)))
        weight = tf.to_float(tf.reshape(weight, (-1, 1)))

        # Per-class scaling: the first classes are emphasised (3, 3, 3, 2),
        # all remaining classes get 1.0. Built from outputChannels so the
        # vector always matches the class dimension; for the default of 16
        # it equals the original hard-coded list.
        # NOTE(review): weighting classes beyond the fourth with 1.0 for
        # outputChannels != 16 is an extension of the original -- confirm.
        leadingScaling = [3.0, 3.0, 3.0, 2.0]
        padding = [1.0] * max(outputChannels - len(leadingScaling), 0)
        crossEntropyScaling = tf.to_float((leadingScaling + padding)[:outputChannels])

        positiveTerm = gtOneHot * tf.log(tf.maximum(predSoftmax, epsilon))
        negativeTerm = (1 - gtOneHot) * tf.log(tf.maximum(1 - predSoftmax, epsilon))

        crossEntropy = -tf.reduce_sum(
            (negativeTerm + positiveTerm) * ss * crossEntropyScaling * weight,
            reduction_indices=[1])

        crossEntropySum = tf.reduce_sum(crossEntropy, name="cross_entropy_sum")
        return crossEntropySum
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号