train.py source code

python

Project: bit-rnn    Author: hqythu
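
The excerpt below is the main function of train.py. It loads the PTB dataset with reader.ptb_raw_data, builds a training PTBModel plus validation and test models that reuse the same variables, drops the learning rate in three stages over the epochs, prints training and validation perplexity after each epoch, and finally reports test perplexity.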
def main(_):
    if not FLAGS.data_path:
        raise ValueError("Must set --data_path to PTB data directory")

    raw_data = reader.ptb_raw_data(FLAGS.data_path)
    train_data, valid_data, test_data, _ = raw_data

    config = get_config()
    eval_config = get_config()
    # evaluation runs one token at a time so RNN state carries across the whole sequence
    eval_config.batch_size = 1
    eval_config.num_steps = 1

    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.uniform_unit_scaling_initializer()
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            m = PTBModel(is_training=True, config=config)
        with tf.variable_scope("model", reuse=True, initializer=initializer):
            mvalid = PTBModel(is_training=False, config=config)
            mtest = PTBModel(is_training=False, config=eval_config)

        tf.global_variables_initializer().run()

        def get_learning_rate(epoch, config):
            # three-stage schedule: base rate, then 10x smaller, then 100x smaller
            base_lr = config.learning_rate
            if epoch <= config.nr_epoch_first_stage:
                return base_lr
            elif epoch <= config.nr_epoch_second_stage:
                return base_lr * 0.1
            else:
                return base_lr * 0.01

        for i in range(config.max_epoch):
            # update the model's learning rate before running this epoch
            m.assign_lr(session, get_learning_rate(i, config))

            print("Epoch: %d Learning rate: %f"
                  % (i + 1, session.run(m.lr)))
            train_perplexity = run_epoch(
                session, m, train_data, m.train_op, verbose=True)
            print("Epoch: %d Train Perplexity: %.3f"
                  % (i + 1, train_perplexity))
            valid_perplexity = run_epoch(
                session, mvalid, valid_data, tf.no_op())
            print("Epoch: %d Valid Perplexity: %.3f"
                  % (i + 1, valid_perplexity))

        test_perplexity = run_epoch(
            session, mtest, test_data, tf.no_op())
        print("Test Perplexity: %.3f" % test_perplexity)