def main(_):
    """Train a PTB language model and report perplexities.

    Builds one training model and two weight-sharing evaluation models
    (validation and test) in the same variable scope, runs `max_epoch`
    training epochs with a piecewise-constant learning-rate schedule,
    and finally evaluates on the test set.

    Args:
        _: unused argv remainder passed in by the tf.app/absl runner.

    Raises:
        ValueError: if the --data_path flag was not supplied.
    """
    if not FLAGS.data_path:
        raise ValueError("Must set --data_path to PTB data directory")
    raw_data = reader.ptb_raw_data(FLAGS.data_path)
    train_data, valid_data, test_data, _ = raw_data

    config = get_config()
    # Evaluation feeds one token at a time (batch_size = num_steps = 1)
    # so the RNN state is carried exactly across the whole sequence.
    eval_config = get_config()
    eval_config.batch_size = 1
    eval_config.num_steps = 1

    def get_learning_rate(epoch, config):
        """Three-stage schedule: base lr, then lr/10, then lr/100."""
        base_lr = config.learning_rate
        if epoch <= config.nr_epoch_first_stage:
            return base_lr
        elif epoch <= config.nr_epoch_second_stage:
            return base_lr * 0.1
        else:
            return base_lr * 0.01

    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.uniform_unit_scaling_initializer()
        # The training model creates the variables (reuse=None); the
        # eval models reuse those same weights (reuse=True).
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            m = PTBModel(is_training=True, config=config)
        with tf.variable_scope("model", reuse=True, initializer=initializer):
            mvalid = PTBModel(is_training=False, config=config)
            mtest = PTBModel(is_training=False, config=eval_config)

        tf.global_variables_initializer().run()

        for i in range(config.max_epoch):
            m.assign_lr(session, get_learning_rate(i, config))
            print("Epoch: %d Learning rate: %f"
                  % (i + 1, session.run(m.lr)))

            train_perplexity = run_epoch(
                session, m, train_data, m.train_op, verbose=True)
            print("Epoch: %d Train Perplexity: %.3f"
                  % (i + 1, train_perplexity))

            # Validation uses tf.no_op() as the "train op" so no
            # gradient update happens during evaluation.
            valid_perplexity = run_epoch(
                session, mvalid, valid_data, tf.no_op())
            print("Epoch: %d Valid Perplexity: %.3f"
                  % (i + 1, valid_perplexity))

        test_perplexity = run_epoch(
            session, mtest, test_data, tf.no_op())
        print("Test Perplexity: %.3f" % test_perplexity)
# NOTE(review): removed stray web-page artifacts ("评论列表"/"文章目录",
# i.e. blog "comment list"/"table of contents") that were pasted in by
# mistake — they are not Python and broke the file's syntax.