training.py source code

python

Project: tefla  Author: litan
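import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.placeholder, tf.get_collection)


def _setup_classification_predictions_and_loss(self):
    # Training graph: first call to the model (reuse=None).
    self.training_end_points = self.model(is_training=True, reuse=None)
    self.inputs = self.training_end_points['inputs']
    training_logits = self.training_end_points['logits']
    self.training_predictions = self.training_end_points['predictions']

    # Validation graph: reuse=True shares the variables built above.
    self.validation_end_points = self.model(is_training=False, reuse=True)
    self.validation_inputs = self.validation_end_points['inputs']
    validation_logits = self.validation_end_points['logits']
    self.validation_predictions = self.validation_end_points['predictions']

    with tf.name_scope('predictions'):
        # Integer class ids, one per example in the batch.
        self.target = tf.placeholder(tf.int32, shape=(None,), name='target')
    with tf.name_scope('loss'):
        # Mean cross-entropy over the batch for the training graph.
        training_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=training_logits, labels=self.target))

        # Same loss on the validation graph, fed through the same target placeholder.
        self.validation_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=validation_logits, labels=self.target))

        # Sum of all registered regularization losses, weighted by cnf['l2_reg'] (default 0.0).
        # tf.add_n needs a non-empty list, so the model must register at least one such loss.
        l2_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        self.regularized_training_loss = training_loss + l2_loss * self.cnf.get('l2_reg', 0.0)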
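
The loss built above can be checked by hand on a small batch. The following is a minimal pure-Python sketch of the same arithmetic (mean sparse softmax cross-entropy plus a weighted sum of regularization penalties), not tefla's or TensorFlow's implementation. The function names `sparse_softmax_cross_entropy` and `regularized_loss`, the toy logits, and the `reg_losses` values are all hypothetical and chosen only for illustration.

```python
import math


def sparse_softmax_cross_entropy(logits, label):
    """Cross-entropy of one example: -log(softmax(logits)[label])."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def regularized_loss(batch_logits, labels, reg_losses, l2_reg=0.0):
    """Mean cross-entropy over the batch plus l2_reg times the summed penalties."""
    per_example = [sparse_softmax_cross_entropy(z, y)
                   for z, y in zip(batch_logits, labels)]
    data_loss = sum(per_example) / len(per_example)
    return data_loss + l2_reg * sum(reg_losses)


# Hypothetical toy batch: two examples, three classes.
batch_logits = [[2.0, 1.0, 0.1], [0.0, 0.0, 0.0]]
labels = [0, 2]             # integer class ids, like the 'target' placeholder
reg_losses = [0.5, 1.5]     # stand-ins for per-layer weight penalties

print(regularized_loss(batch_logits, labels, reg_losses, l2_reg=0.01))
```

With `l2_reg=0.0` the second term vanishes and the result is the plain mean cross-entropy, which mirrors the default of `self.cnf.get('l2_reg', 0.0)` in the method above.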