def weighted_loss(y_true, y_softmax_conv, weight):
    """Compute a per-pixel weighted squared-error loss on the foreground probability.

    With ``p`` = channel 1 of ``y_softmax_conv`` (the foreground probability),
    ``y`` = channel 0 of ``y_true`` and ``w`` = channel 0 of ``weight``, the
    loss is

        mean(0.5 * (1 - p)**2 * y * w + 0.5 * p**2 * (1 - y) * w)

    averaged over every element of every batch item.

    Args:
        y_true: Ground-truth labels, expected shape
            [batch_size, depth, height, width, 1]; only channel 0 is read.
        y_softmax_conv: Softmax network output, expected shape
            [batch_size, depth, height, width, 2]; channel 1 is used as the
            foreground probability.
        weight: Per-pixel weight map, expected shape
            [batch_size, depth, height, width, 1]; only channel 0 is read.

    Returns:
        A scalar float32 tensor holding the mean weighted loss.
    """
    # Select the relevant channel and flatten, so the three tensors line up
    # element-wise regardless of the spatial dimensions.
    y_true = tf.cast(tf.reshape(y_true[..., 0], [-1]), tf.float32)
    weight = tf.cast(tf.reshape(weight[..., 0], [-1]), tf.float32)
    y_conv = tf.cast(tf.reshape(y_softmax_conv[..., 1], [-1]), tf.float32)
    # Use 0.5 rather than ``1 / 2``: under Python 2, integer division makes
    # ``1 / 2 == 0``, which would silently zero the entire loss.
    loss_pos = 0.5 * tf.square(1.0 - y_conv) * y_true * weight
    loss_neg = 0.5 * tf.square(y_conv) * (1.0 - y_true) * weight
    return tf.reduce_mean(loss_pos + loss_neg)
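# 评论列表 ("comment list") / 文章目录 ("table of contents"):
# leftover web-page navigation text from the original source, not code.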