def blurred_cross_entropy(output, target, filter_size=11, sampling_range=3.5, pixel_weights=None):
    """
    Apply a Gaussian smoothing filter to the target probabilities (i.e. the one-hot
    representation of target) and compute the cross entropy loss between softmax(output)
    and the blurred target probabilities.

    :param output: A rank-4 or rank-5 tensor with shape=(samples, [sequence_position,] x, y, num_classes)
        representing the network input of the output layer (not activated, i.e. logits).
        The class dimension (num_classes) must be statically known.
    :param target: A rank-3 or rank-4 tensor with shape=(samples, [sequence_position,] x, y) representing
        the target labels. It must contain int values in 0..num_classes-1.
    :param filter_size: Size of the Gaussian filter applied to the target probabilities. It is
        forwarded unchanged to ``gaussian_blur``; the accepted forms (int or a length-2 list)
        are determined there.
    :param sampling_range: Forwarded unchanged to ``gaussian_blur``. Presumably the extent (in
        standard deviations) over which the Gaussian kernel is sampled -- TODO confirm against
        gaussian_blur.
    :param pixel_weights: A rank-3 or rank-4 tensor with shape=(samples, [sequence_position,] x, y)
        representing factors that will be applied to the loss of the corresponding pixel. This can
        be used e.g. to void certain pixels by weighting them with 0, i.e. suppress their error
        induction. If all weights are 0 the result is a division by zero (NaN).
    :return: A scalar tensor holding the blurred cross entropy loss. Without pixel_weights this is
        the mean over all pixels; with pixel_weights it is the weighted sum of the per-pixel losses
        divided by the sum of the weights.
    """
    output_shape = output.shape.as_list()
    rank = len(output_shape)
    num_classes = output_shape[-1]

    # Convert the integer labels to one-hot target probabilities.
    one_hot = tf.one_hot(target, num_classes, dtype=tf.float32)

    # The blur is applied to rank-4 (batch, x, y, classes) input, so fold a leading
    # sequence dimension into the batch dimension first. Using -1 instead of a static
    # product keeps this working when the batch size is unknown (None).
    if rank > 4:
        one_hot = tf.reshape(one_hot, [-1] + output_shape[-3:])
    blurred_target = gaussian_blur(one_hot, filter_size, sampling_range)
    if rank > 4:
        # Restore the original layout; tf.shape() also handles unknown dimensions.
        blurred_target = tf.reshape(blurred_target, tf.shape(output))

    # Numerically stable log(softmax(output)) over the last (class) axis.
    log_pred = tf.nn.log_softmax(output)

    # Per-pixel cross entropy: shape (samples, [sequence_position,] x, y).
    cross_entropy = -tf.reduce_sum(blurred_target * log_pred, axis=-1)

    if pixel_weights is None:
        return tf.reduce_mean(cross_entropy)

    # Weight the loss *after* reducing over classes: pixel_weights has no class axis,
    # so multiplying it into log_pred would broadcast against the wrong dimensions.
    # Because the class sum is linear, this is exactly the intended per-pixel weighting.
    pixel_weights = tf.cast(pixel_weights, cross_entropy.dtype)
    return tf.reduce_sum(cross_entropy * pixel_weights) / tf.reduce_sum(pixel_weights)
评论列表
文章目录