network.py source code

python

Project: fcn | Author: ilovin
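import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.get_collection, tf.GraphKeys)
# `cfg` and `get_valid_logits_and_labels` are defined elsewhere in the project.

def build_loss(self):
        # Per-pixel logits [N, H, W, NCLASSES] and ground-truth labels [N, H, W, 1]
        upsampled_batch = self.get_output('output_logits')
        annotation_batch = self.get_output('label')

        # Real classes 0..NCLASSES-1, plus 255 as the "ignore" (void) label
        class_labels = list(range(cfg.NCLASSES))
        class_labels.append(255)
        print("class_labels:", class_labels)

        # Drop the trailing channel axis and cast the labels to int32
        annotation_batch = tf.squeeze(annotation_batch, axis=3)
        annotation_batch = tf.cast(annotation_batch, tf.int32)

        # Keep only pixels whose label is not 255; labels come back one-hot encoded,
        # which is what softmax_cross_entropy_with_logits expects
        valid_annotation_batch, valid_logits_batch = get_valid_logits_and_labels(
            logits_batch_tensor=upsampled_batch,
            annotation_batch_tensor=annotation_batch,
            class_labels=class_labels)

        # Mean cross-entropy over the valid pixels only
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=valid_logits_batch,
                                                                      labels=valid_annotation_batch))

        # Add the weight-decay terms registered by the layers, if any exist
        if cfg.TRAIN.WEIGHT_DECAY > 0:
            regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if regularization_losses:  # tf.add_n raises on an empty list
                loss = tf.add_n(regularization_losses) + loss

        return loss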
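The loss above has two parts: a softmax cross-entropy averaged only over pixels whose label is not 255, and an optional sum of regularization terms. The NumPy sketch below mirrors that arithmetic outside TensorFlow so the masking can be checked by hand. It assumes that `get_valid_logits_and_labels` drops the pixels labelled 255 and one-hot encodes the rest (as its use with `softmax_cross_entropy_with_logits` suggests), and that each registered regularization loss has the form `weight_decay * tf.nn.l2_loss(w)`; the project's layer code is not shown, so that second point is an assumption. `masked_softmax_cross_entropy`, `l2_penalty` and `IGNORE_LABEL` are hypothetical names used only for illustration.

```python
import numpy as np

IGNORE_LABEL = 255  # assumption: same "void" label that build_loss appends to class_labels


def masked_softmax_cross_entropy(logits, labels, num_classes, ignore_label=IGNORE_LABEL):
    """Mean per-pixel softmax cross-entropy, skipping pixels labelled ignore_label.

    logits: float array (N, H, W, C); labels: int array (N, H, W).
    If every pixel is ignored, the mean of an empty array is NaN.
    """
    valid = labels != ignore_label               # boolean mask over (N, H, W)
    valid_logits = logits[valid]                 # (P, C) rows for the P valid pixels
    valid_labels = labels[valid]                 # (P,)
    one_hot = np.eye(num_classes)[valid_labels]  # (P, C) one-hot targets
    # Numerically stable log-softmax: subtract each row's max before exponentiating
    shifted = valid_logits - valid_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    per_pixel = -(one_hot * log_probs).sum(axis=1)
    return per_pixel.mean()


def l2_penalty(weights, weight_decay):
    """Hypothetical regularizer: weight_decay * sum(||w||^2 / 2), the per-tensor form of tf.nn.l2_loss."""
    return weight_decay * sum(0.5 * np.sum(w ** 2) for w in weights)


# Usage: a 1x2x2 "image" with 3 classes; the top-right pixel is labelled 255 and ignored.
logits = np.zeros((1, 2, 2, 3))
logits[0, 0, 0, 1] = 2.0
labels = np.array([[[1, 255], [0, 2]]])
data_loss = masked_softmax_cross_entropy(logits, labels, num_classes=3)
total_loss = data_loss + l2_penalty([np.ones((3, 3))], weight_decay=5e-4)
```

Because the ignored pixel is removed before the softmax, changing its logits has no effect on `data_loss`; only the three labelled pixels contribute to the average.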