train.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:tf_rnnlm 作者: Ubiqus 项目源码 文件源码
def _build_graph(self):
    """Create the training, validation and test models in the TF graph.

    All three models are built inside the variable scope "Model" with the
    same initializer. The training model is created first with
    ``reuse=False``; the validation and test models follow with
    ``reuse=True``. Because of this, the construction order below matters
    and must not be changed.

    Side effects:
        * sets ``self.config.fast_test`` to False (mutates the shared config);
        * assigns ``self.train_model``, ``self.validation_model`` and
          ``self.test_model``;
        * registers scalar summaries for training loss, learning rate and
          validation loss.
    """
    train_cfg = self.config
    train_cfg.fast_test = False

    # Test config: built via Config(clone=...) from the (already mutated)
    # training config, then forced to a batch size of 1.
    test_cfg = Config(clone=train_cfg)
    test_cfg.batch_size = 1

    init = self.model_initializer

    def model_scope(reuse):
        # Every model lives in the same "Model" variable scope, which is how
        # the reuse=True models pick up the variables of the reuse=False one.
        return tf.variable_scope("Model", reuse=reuse, initializer=init)

    with tf.name_scope("Train"):
        with model_scope(False):
            self.train_model = self.Model(
                config=train_cfg, is_training=True, loss_fct=self.loss_fct)
        tf.summary.scalar("Training Loss", self.train_model.cost)
        tf.summary.scalar("Learning Rate", self.train_model.lr)

        # NOTE(review): "Valid" is nested inside the "Train" name scope, unlike
        # "Test" below, which is a sibling. This may be accidental; confirm
        # before changing, since it affects op names and summary tags.
        # The validation model reuses the *training* config (not test_cfg)
        # and always uses the "softmax" loss.
        with tf.name_scope("Valid"):
            with model_scope(True):
                self.validation_model = self.Model(
                    config=train_cfg, is_training=False, loss_fct="softmax")
            tf.summary.scalar("Validation Loss", self.validation_model.cost)

    with tf.name_scope("Test"):
        with model_scope(True):
            self.test_model = self.Model(config=test_cfg, is_training=False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号