def loss(self, img_batch, label_batch):
    """Build the network on an image batch and return its mean pixel-wise loss.

    Args:
      img_batch: batch of pre-processed images; cast to float32 before
        being fed to the network.
      label_batch: ground-truth labels for the batch; handed to
        `self.prepare_label` together with the size taken from
        dimensions 1 and 2 of the network output.

    Returns:
      The softmax cross-entropy between the flattened network output and
      the flattened prepared labels, averaged over all rows.
    """
    # The network is built with keep_prob fixed at 0.5.
    images = tf.cast(img_batch, tf.float32)
    keep = tf.constant(0.5)
    logits_map = self._create_network(images, keep_prob=keep)

    # One row of n_classes scores per position.
    flat_logits = tf.reshape(logits_map, [-1, n_classes])

    # Labels need to be resized to the output and one-hot encoded;
    # prepare_label is expected to do both (defined outside this block).
    output_size = tf.stack(logits_map.get_shape()[1:3])
    prepared_labels = self.prepare_label(label_batch, output_size)
    flat_labels = tf.reshape(prepared_labels, [-1, n_classes])

    # Pixel-wise softmax loss, averaged into a single value.
    per_pixel = tf.nn.softmax_cross_entropy_with_logits(
        logits=flat_logits, labels=flat_labels)
    return tf.reduce_mean(per_pixel)
评论列表
文章目录