train.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:tf_rnnlm 作者: Ubiqus 项目源码 文件源码
def _run(self):
    """Train epoch by epoch, validating and checkpointing after each one.

    Builds a variable-initializer op, opens a ``tf.Session`` and runs the
    initializer. Then, while ``config.epoch`` is at most
    ``config.max_max_epoch``, each iteration:

    1. Sets the training model's learning rate to
       ``learning_rate * lr_decay ** max(epoch - max_epoch, 0)``.
    2. Runs one training epoch over ``data.train`` (with the model's
       ``train_op``) and one epoch of the validation model over
       ``data.valid``, printing the perplexity ``run_epoch`` returns for each.
    3. Resets ``config.step``, advances ``config.epoch``, calls
       ``config.save()`` and writes a checkpoint through ``self.io``.

    The test model is read from ``self`` but is not evaluated here.
    """
    train_model = self.train_model
    valid_model = self.validation_model
    _test_model = self.test_model  # read but unused by this loop
    cfg = self.config
    corpus = self.data
    opts = self.params

    init_all = tf.initialize_all_variables()
    with tf.Session() as sess:
        sess.run(init_all)

        print("Starting training from epoch %d using %s loss"
              % (cfg.epoch, train_model.loss_fct))

        while cfg.epoch <= cfg.max_max_epoch:
            epoch = cfg.epoch

            # The exponent is 0 while epoch <= max_epoch, so the base rate is
            # used unchanged. After that, each extra epoch multiplies in
            # another factor of lr_decay.
            decay_factor = cfg.lr_decay ** max(epoch - cfg.max_epoch, 0.0)
            train_model.assign_lr(sess, cfg.learning_rate * decay_factor)

            current_lr = sess.run(train_model.lr)
            print("\nEpoch: %d Learning rate: %.3f" % (epoch, current_lr))

            train_ppl = run_epoch(
                sess,
                train_model,
                corpus.train,
                eval_op=train_model.train_op,
                verbose=True,
                opIO=self.io,
                log_rate=opts.log_rate,
                save_rate=opts.save_rate,
            )
            print("Epoch: %d Train Perplexity: %.3f" % (epoch, train_ppl))

            print("Validation using %s loss" % valid_model.loss_fct)
            valid_ppl = run_epoch(sess, valid_model, corpus.valid)
            print("Epoch: %d Valid Perplexity: %.3f" % (epoch, valid_ppl))

            # Epoch finished: reset the in-epoch step counter, advance the
            # epoch and persist the config (presumably so a restarted run
            # resumes at the next epoch -- confirm against the config loader).
            cfg.step = 0
            cfg.epoch += 1
            cfg.save()

            # NOTE(review): the checkpoint name uses the already-incremented
            # epoch, so "ep_N.ckpt" holds the weights after epoch N-1 was
            # trained -- presumably to match the epoch a resumed run starts
            # from; confirm against the checkpoint-loading code.
            self.io.save_checkpoint(sess, "ep_%d.ckpt" % cfg.epoch)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号