ptb_word_lm.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:Language-Model-2016 作者: feizhihui 项目源码 文件源码
def main(data_path='/home/feizhihui/MyData/dataset/PTB/',
         save_path="./PTB_Model/PTB_Variables.ckpt"):
    """Train, validate and test a PTB word-level language model, then save it.

    Builds three ``PTBModel`` instances that share one set of variables under
    the ``"model"`` variable scope:

    * a training model (``is_training=True``, ``Config()``),
    * a validation model (``is_training=False``, ``Config()``),
    * a test model (``is_training=False``) whose config has
      ``batch_size = 1`` and ``num_steps = 1``.

    Trains for ``config.max_max_epoch`` epochs. After each epoch it prints the
    learning rate and the train/validation perplexity and accuracy. It then
    evaluates once on the test split and writes a checkpoint of all variables.

    Args:
        data_path: Directory passed to ``reader.ptb_raw_data``. Defaults to
            the original hard-coded dataset location.
        save_path: Checkpoint path passed to ``tf.train.Saver.save``. Its
            parent directory is created if it does not exist yet.
    """
    import os

    # Load the corpus splits; the fourth value returned by ptb_raw_data is
    # not used here.
    raw_data = reader.ptb_raw_data(data_path)
    train_data, valid_data, test_data, _ = raw_data

    config = Config()
    # The test model uses the same hyper-parameters but processes a single
    # sequence one step at a time.
    eval_config = Config()
    eval_config.batch_size = 1
    eval_config.num_steps = 1

    with tf.Graph().as_default(), tf.Session() as session:
        # All variables are initialised uniformly in [-init_scale, init_scale].
        initializer = tf.random_uniform_initializer(-config.init_scale,
                                                    config.init_scale)
        # reuse=None: the training model creates the variables.
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            m = PTBModel(is_training=True, config=config)
        # reuse=True: the validation and test models share the training
        # model's variables instead of creating new ones.
        with tf.variable_scope("model", reuse=True, initializer=initializer):
            mvalid = PTBModel(is_training=False, config=config)
            mtest = PTBModel(is_training=False, config=eval_config)

        # Build the evaluation no-op once. Calling tf.no_op() inside the epoch
        # loop would add a fresh node to the graph on every epoch.
        eval_op = tf.no_op()

        tf.global_variables_initializer().run()

        for i in range(config.max_max_epoch):
            # The exponent is 0 while i <= max_epoch and then grows 1, 2, ...
            # So the learning rate stays constant at first, then decays
            # geometrically by a factor of lr_decay per epoch.
            lr_decay = config.lr_decay ** max(i - config.max_epoch, 0.0)
            m.assign_lr(session, config.learning_rate * lr_decay)

            print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
            train_perplexity, train_accuracy = run_epoch(session, m, train_data, m.train_op,
                                                         verbose=True)
            print("Epoch: %d Train Perplexity: %.3f, Train Accuracy: %.3f"
                  % (i + 1, train_perplexity, train_accuracy))
            # Validation runs the shared-weight model without a training op.
            valid_perplexity, valid_accuracy = run_epoch(session, mvalid, valid_data, eval_op, verbose=True)
            print("Epoch: %d Valid Perplexity: %.3f, Valid Accuracy: %.3f"
                  % (i + 1, valid_perplexity, valid_accuracy))

        test_perplexity, test_accuracy = run_epoch(session, mtest, test_data, eval_op, verbose=True)
        print("Test Perplexity: %.3f, Test Accuracy: %.3f" % (test_perplexity, test_accuracy))

        # Make sure the checkpoint's parent directory exists. Otherwise the
        # save step fails only after the full training run has completed.
        save_dir = os.path.dirname(save_path)
        if save_dir and not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        saver = tf.train.Saver()
        save_path = saver.save(session, save_path)
        print("Save to path: ", save_path)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号