def main():
    # Read the raw PTB data: training, validation and test sets.
    raw_data = reader.ptb_raw_data('/home/feizhihui/MyData/dataset/PTB/')
    train_data, valid_data, test_data, _ = raw_data

    # Pick a configuration (small, medium, large, or test).
    # Two configs are used: the eval config runs with batch_size = 1 and num_steps = 1.
    config = Config()
    eval_config = Config()
    eval_config.batch_size = 1
    eval_config.num_steps = 1
    with tf.Graph().as_default(), tf.Session() as session:
        # Variables are initialized uniformly in [-init_scale, init_scale].
        initializer = tf.random_uniform_initializer(-config.init_scale,
                                                    config.init_scale)

        # Training model: reuse=None creates fresh variables in the "model" scope.
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            m = PTBModel(is_training=True, config=config)
        # Validation and test models share the training variables (reuse=True).
        with tf.variable_scope("model", reuse=True, initializer=initializer):
            mvalid = PTBModel(is_training=False, config=config)
            mtest = PTBModel(is_training=False, config=eval_config)

        tf.global_variables_initializer().run()
        for i in range(config.max_max_epoch):
            # lr_decay ** (0, ..., 0, 1, 2, ...): the decay kicks in only after
            # the first max_epoch epochs have passed.
            lr_decay = config.lr_decay ** max(i - config.max_epoch, 0.0)
            m.assign_lr(session, config.learning_rate * lr_decay)

            print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
            train_perplexity, train_accuracy = run_epoch(session, m, train_data, m.train_op,
                                                         verbose=True)
            print("Epoch: %d Train Perplexity: %.3f, Train Accuracy: %.3f"
                  % (i + 1, train_perplexity, train_accuracy))
            valid_perplexity, valid_accuracy = run_epoch(session, mvalid, valid_data,
                                                         tf.no_op(), verbose=True)
            print("Epoch: %d Valid Perplexity: %.3f, Valid Accuracy: %.3f"
                  % (i + 1, valid_perplexity, valid_accuracy))

        # Final evaluation on the test set with the single-step eval config.
        test_perplexity, test_accuracy = run_epoch(session, mtest, test_data, tf.no_op(), verbose=True)
        print("Test Perplexity: %.3f, Test Accuracy: %.3f" % (test_perplexity, test_accuracy))
        saver = tf.train.Saver()
        save_path = saver.save(session, "./PTB_Model/PTB_Variables.ckpt")
        print("Save to path: ", save_path)