training.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def _setup_summaries(self):
    """Build the summary ops used during training and validation.

    Creates the ``self.epoch_loss`` scalar placeholder and registers
    summaries into four collections:

    * ``TRAINING_EPOCH_SUMMARIES``: learning rate and training loss.
    * ``TRAINING_BATCH_SUMMARIES``: the inputs as images (only when
      ``self.inputs`` has a known rank of 4), activations of every entry in
      ``self.training_end_points``, trainable parameters, and the gradients
      in ``self.grads_and_vars``.
    * ``VALIDATION_BATCH_SUMMARIES``: activations of every entry in
      ``self.validation_end_points``.
    * ``VALIDATION_EPOCH_SUMMARIES``: validation loss plus one scalar per
      ``(metric_name, _)`` pair in ``self.validation_metrics_def``.

    Side effects:
        Sets ``self.epoch_loss`` (a float32 scalar placeholder) and
        ``self.validation_metric_placeholders`` (a tuple of float32 scalar
        placeholders, in the same order as ``self.validation_metrics_def``).
    """
    with tf.name_scope('summaries'):
        # The same scalar placeholder backs both the training-loss and the
        # validation-loss summaries; they are kept apart only by collection.
        # NOTE(review): presumably the caller feeds whichever epoch loss is
        # being written when it evaluates a given collection -- confirm.
        self.epoch_loss = tf.placeholder(
            tf.float32, shape=[], name="epoch_loss")

        # Training summaries
        tf.summary.scalar('learning rate', self.learning_rate,
                          collections=[TRAINING_EPOCH_SUMMARIES])
        tf.summary.scalar('training (cross entropy) loss', self.epoch_loss,
                          collections=[TRAINING_EPOCH_SUMMARIES])
        # `ndims` is None for an unknown rank, whereas len() on such a shape
        # raises ValueError; comparing ndims simply skips the image summary.
        if self.inputs.get_shape().ndims == 4:
            summary.summary_image(self.inputs, 'inputs', max_images=10,
                                  collections=[TRAINING_BATCH_SUMMARIES])
        # `.items()` works on both Python 2 and 3 (`.iteritems()` is 2-only).
        for key, val in self.training_end_points.items():
            summary.summary_activation(
                val, name=key, collections=[TRAINING_BATCH_SUMMARIES])
        summary.summary_trainable_params(
            ['scalar', 'histogram', 'norm'],
            collections=[TRAINING_BATCH_SUMMARIES])
        summary.summary_gradients(
            self.grads_and_vars, ['scalar', 'histogram', 'norm'],
            collections=[TRAINING_BATCH_SUMMARIES])

        # Validation summaries
        for key, val in self.validation_end_points.items():
            summary.summary_activation(
                val, name=key, collections=[VALIDATION_BATCH_SUMMARIES])

        tf.summary.scalar('validation loss', self.epoch_loss,
                          collections=[VALIDATION_EPOCH_SUMMARIES])
        metric_placeholders = []
        for metric_name, _ in self.validation_metrics_def:
            # Spaces in the metric name are replaced with underscores for the
            # placeholder's op name; the summary is registered under
            # `metric_name` as given.
            validation_metric = tf.placeholder(
                tf.float32, shape=[], name=metric_name.replace(' ', '_'))
            metric_placeholders.append(validation_metric)
            tf.summary.scalar(metric_name, validation_metric,
                              collections=[VALIDATION_EPOCH_SUMMARIES])
        self.validation_metric_placeholders = tuple(metric_placeholders)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号