def main(_):
    """Train (or restore) the PTB language model, then dump its states.

    Builds three graph copies that share the variables created under the
    "Model" variable scope:

    * ``m``: the training model (``is_training=True``) over the training data.
    * ``mstates``: a non-training copy over the training data, used only to
      record the states returned by ``run_epoch``.
    * ``mvalid``: a non-training copy over the validation data.

    If ``--load_path`` is set, the latest checkpoint in that directory is
    restored and no training is done. Otherwise the model is trained for
    ``FLAGS.max_max_epoch`` epochs. The learning rate decays by
    ``FLAGS.lr_decay`` per epoch once ``FLAGS.max_epoch`` epochs have passed.

    In both cases, ``mstates`` is then run once over the training set. The
    collected states are reshaped to ``(-1, mstates.size)`` and written to
    dataset ``"states1"`` in ``states.h5``. Finally, the model is saved to
    ``--save_path`` if that flag is set.

    Args:
      _: Unused positional argument (presumably argv from ``tf.app.run``;
        confirm against the caller).

    Raises:
      ValueError: If ``--data_path`` is not set, or if ``--load_path``
        contains no checkpoint.
    """
    if not FLAGS.data_path:
        raise ValueError("Must set --data_path to PTB data directory")
    raw_data = reader.ptb_raw_data(FLAGS.data_path, True)
    train_data, valid_data, _ = raw_data

    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-FLAGS.init_scale,
                                                    FLAGS.init_scale)

        with tf.name_scope("Train"):
            train_input = PTBInput(data=train_data, name="TrainInput")
            # reuse=None: this is the copy that creates the shared variables.
            with tf.variable_scope("Model", reuse=None,
                                   initializer=initializer):
                m = PTBModel(is_training=True, input_=train_input)
            tf.summary.scalar("Training Loss", m.cost)
            tf.summary.scalar("Learning Rate", m.lr)

        # Non-training copy over the training data. It reuses the trained
        # weights and is only used to record states after training/restore.
        with tf.name_scope("Train_states"):
            states_input = PTBInput(data=train_data, name="TrainInput")
            with tf.variable_scope("Model", reuse=True,
                                   initializer=initializer):
                mstates = PTBModel(is_training=False, input_=states_input)
            tf.summary.scalar("Training Loss", mstates.cost)

        with tf.name_scope("Valid"):
            valid_input = PTBInput(data=valid_data, name="ValidInput")
            with tf.variable_scope("Model", reuse=True,
                                   initializer=initializer):
                mvalid = PTBModel(is_training=False, input_=valid_input)
            tf.summary.scalar("Validation Loss", mvalid.cost)

        sv = tf.train.Supervisor(logdir=FLAGS.save_path)
        with sv.managed_session() as session:
            if FLAGS.load_path:
                # latest_checkpoint returns None when no checkpoint exists;
                # fail with a clear message instead of inside restore().
                checkpoint = tf.train.latest_checkpoint(FLAGS.load_path)
                if checkpoint is None:
                    raise ValueError(
                        "No checkpoint found in --load_path %s"
                        % FLAGS.load_path)
                sv.saver.restore(session, checkpoint)
            else:
                for i in range(FLAGS.max_max_epoch):
                    # Exponent is 0 (no decay) for the first max_epoch epochs.
                    lr_decay = FLAGS.lr_decay ** max(
                        i + 1 - FLAGS.max_epoch, 0.0)
                    m.assign_lr(session, FLAGS.learning_rate * lr_decay)
                    print("Epoch: %d Learning rate: %.3f"
                          % (i + 1, session.run(m.lr)))
                    train_perplexity, stat = run_epoch(
                        session, m, eval_op=m.train_op, verbose=True)
                    print(stat.shape)
                    print("Epoch: %d Train Perplexity: %.3f"
                          % (i + 1, train_perplexity))
                    valid_perplexity, _ = run_epoch(session, mvalid)
                    print("Epoch: %d Valid Perplexity: %.3f"
                          % (i + 1, valid_perplexity))

            # Run over the training set and store the states. This call has
            # no eval_op on purpose. m.train_op is the op the training loop
            # runs, so evaluating it here would presumably keep updating the
            # shared weights while states are being recorded.
            _, stat = run_epoch(session, mstates, verbose=True)
            stat = np.reshape(stat, (-1, mstates.size))
            # The context manager closes the file even if the write fails.
            with h5py.File("states.h5", "w") as f:
                f["states1"] = stat

            if FLAGS.save_path:
                print("Saving model to %s." % FLAGS.save_path)
                sv.saver.save(session, FLAGS.save_path,
                              global_step=sv.global_step)
# (Scraped web-page residue, not code: "Comment list" / "Table of contents".)