def loss(c_fuse, s_fuse, labels):
    """Build the total training loss for the contour and segment outputs.

    ``labels`` is split into two equal halves along axis 3. The first half
    is used as the contour labels and the second half as the segment labels.
    ``_add_cross_entropy`` is then called once for each (labels, output)
    pair. The function returns the sum of every tensor currently in the
    ``'losses'`` graph collection.

    NOTE(review): ``_add_cross_entropy`` is defined elsewhere. It is
    presumably what puts each head's cross-entropy term into the
    ``'losses'`` collection; confirm. Any other tensor already in that
    collection (for example, weight-decay terms added by the model) is
    also included in the returned total.

    Args:
        c_fuse: Contours output map from inference().
        s_fuse: Segments output map from inference().
        labels: Labels from distorted_inputs() or inputs(). Its size along
            axis 3 must be divisible by 2 so it can be split into two equal
            halves. Presumably shaped [FLAGS.batch_size, 696, 520, 2] (one
            channel per target); confirm against the input pipeline.

    Returns:
        Loss tensor named 'total_loss' of type float: the ``tf.add_n`` sum
        of the tensors in the ``'losses'`` collection.
    """
    # Split labels along axis 3 (presumably the channel axis) into the
    # contour and segment targets. Per the original author's note, each half
    # has shape [FLAGS.batch_size, 696, 520, 1]; TODO confirm.
    contours_labels, segments_labels = tf.split(
        labels, num_or_size_splits=2, axis=3)

    # One loss term per output head. The 'c'/'s' suffix is forwarded to the
    # helper, presumably to distinguish the two terms by name.
    _add_cross_entropy(contours_labels, c_fuse, 'c')
    _add_cross_entropy(segments_labels, s_fuse, 's')

    # Sum everything registered in the 'losses' collection into one tensor.
    return tf.add_n(tf.get_collection('losses'), name='total_loss')
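# 评论列表 ("Comment list"): web-page navigation residue, not code.
# 文章目录 ("Table of contents"): web-page navigation residue, not code.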