import numpy as np
import tensorflow as tf


def pixel_wise_cross_entropy_loss_weighted(logits, labels, class_weights):
    '''
    Weighted cross-entropy loss with one weight per class.
    :param logits: Network output before softmax, shape [..., n_class]
    :param labels: One-hot ground-truth masks, same shape as logits
    :param class_weights: A list of weights, one per class
    :return: Scalar weighted cross-entropy loss
    '''
    n_class = len(class_weights)
    # Flatten so every pixel becomes one row of class scores / one-hot labels
    flat_logits = tf.reshape(logits, [-1, n_class])
    flat_labels = tf.reshape(labels, [-1, n_class])
    # Per-pixel weight: the weight of that pixel's ground-truth class
    class_weights = tf.constant(np.array(class_weights, dtype=np.float32))
    weight_map = tf.multiply(flat_labels, class_weights)
    weight_map = tf.reduce_sum(weight_map, axis=1)
    # Per-pixel cross entropy, scaled by the class weight and averaged over all pixels
    loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits, labels=flat_labels)
    weighted_loss = tf.multiply(loss_map, weight_map)
    loss = tf.reduce_mean(weighted_loss)
    return loss
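
A minimal usage sketch, assuming TF 2.x eager execution; the batch shape (4 images of 128x128 pixels, 3 classes) and the weight list are hypothetical values chosen only for illustration:

# Hypothetical example: 4 images, 128x128 pixels, 3 classes, one-hot labels
batch_logits = tf.random.normal([4, 128, 128, 3])
batch_labels = tf.one_hot(
    tf.random.uniform([4, 128, 128], maxval=3, dtype=tf.int32), depth=3)
loss = pixel_wise_cross_entropy_loss_weighted(batch_logits, batch_labels, [1.0, 2.0, 5.0])
print(loss.numpy())  # a single scalar; rare classes (higher weights) contribute more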