def build_loss(self):
    upsampled_batch = self.get_output('output_logits')
    annotation_batch = self.get_output('label')

    # Valid class ids plus the ignore label (255), which marks pixels
    # that should not contribute to the loss.
    class_labels = [i for i in range(cfg.NCLASSES)]
    class_labels.append(255)
    print("class_labels: ", class_labels)

    # Drop the trailing channel dimension and cast the annotations to int32.
    annotation_batch = tf.squeeze(annotation_batch, axis=3)
    annotation_batch = tf.to_int32(annotation_batch)

    # Keep only the pixels whose label belongs to a valid class;
    # ignore-label pixels are masked out of both logits and labels.
    valid_annotation_batch, valid_logits_batch = get_valid_logits_and_labels(
        logits_batch_tensor=upsampled_batch,
        annotation_batch_tensor=annotation_batch,
        class_labels=class_labels)

    # Mean softmax cross-entropy over the valid pixels only.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=valid_logits_batch,
                                                labels=valid_annotation_batch))

    # Add the L2 weight-decay terms collected during layer construction.
    if cfg.TRAIN.WEIGHT_DECAY > 0:
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        loss = tf.add_n(regularization_losses) + loss

    return loss
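
The helper get_valid_logits_and_labels is not shown above. As a rough guide, here is a minimal sketch of what such a helper typically does, assuming the last entry of class_labels is the ignore label (255), annotations have shape [batch, H, W], and logits have shape [batch, H, W, num_classes]; the function name and signature mirror the call above, but the body is an illustrative assumption rather than the project's actual implementation.

import tensorflow as tf

def get_valid_logits_and_labels(logits_batch_tensor,
                                annotation_batch_tensor,
                                class_labels):
    """Sketch: keep only pixels whose annotation is a valid class id."""
    valid_classes = class_labels[:-1]   # assumed: last entry is the ignore label
    num_classes = len(valid_classes)

    # Boolean mask of pixels whose label equals one of the valid class ids.
    valid_mask = tf.reduce_any(
        tf.stack([tf.equal(annotation_batch_tensor, c) for c in valid_classes],
                 axis=-1),
        axis=-1)
    valid_indices = tf.where(valid_mask)  # [N, 3] indices into (batch, H, W)

    # Gather the surviving pixels and one-hot encode their labels so they
    # can be fed to softmax_cross_entropy_with_logits.
    valid_labels = tf.gather_nd(annotation_batch_tensor, valid_indices)
    valid_labels_one_hot = tf.one_hot(valid_labels, depth=num_classes)
    valid_logits = tf.gather_nd(logits_batch_tensor, valid_indices)

    return valid_labels_one_hot, valid_logits

Because the ignore-label pixels are removed before the cross-entropy is computed, border or unlabeled regions marked with 255 have no effect on the gradient.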