def main(_):
    """Train the PTB language model, then report test-set perplexity.

    Loads the PTB splits from ``FLAGS.data_path`` and trains ``PTBModel`` for
    ``config.max_max_epoch`` epochs at a fixed learning rate. For each epoch it
    prints the training perplexity. After training it prints the mean of the
    per-epoch average batch times. Finally it evaluates a variable-sharing
    copy of the model on the test split.

    Args:
      _: Unused positional argument (presumably the argv list handed over by
        the app runner -- confirm against the caller).

    Raises:
      ValueError: If ``--data_path`` is not set, or if ``OMP_NUM_THREADS`` is
        set to a non-integer value when running on CPU.
    """
    if not FLAGS.data_path:
        raise ValueError("Must set --data_path to PTB data directory")
    # The validation split is unpacked but unused: validation is disabled here.
    train_data, _valid_data, test_data, _ = reader.ptb_raw_data(FLAGS.data_path)

    config = get_config()
    # Evaluation runs with batch size 1 and a single unrolled step.
    eval_config = get_config()
    eval_config.batch_size = 1
    eval_config.num_steps = 1

    # config.device holds a GPU index; '-1' selects the CPU. Normalising with
    # str() lets an int index (e.g. -1 or 0) work as well as a string one.
    device = str(config.device)
    tf_dev = '/cpu:0' if device == '-1' else '/gpu:' + device
    print(tf_dev)

    if 'cpu' in tf_dev:
        # On CPU, size TensorFlow's intra-op thread pool from OMP_NUM_THREADS.
        # Default to 1 when the variable is unset or empty.
        raw_threads = os.getenv('OMP_NUM_THREADS', '').strip()
        try:
            num_threads = int(raw_threads) if raw_threads else 1
        except ValueError:
            raise ValueError(
                "OMP_NUM_THREADS must be an integer, got %r" % raw_threads)
        tconfig = tf.ConfigProto(allow_soft_placement=True,
                                 intra_op_parallelism_threads=num_threads)
    else:
        tconfig = tf.ConfigProto(allow_soft_placement=True)

    with tf.Graph().as_default(), tf.device(tf_dev), \
            tf.Session(config=tconfig) as session:
        initializer = tf.random_uniform_initializer(-config.init_scale,
                                                    config.init_scale)
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            m = PTBModel(is_training=True, config=config)
        # reuse=True makes the evaluation model share the trained variables.
        with tf.variable_scope("model", reuse=True, initializer=initializer):
            mtest = PTBModel(is_training=False, config=eval_config)

        tf.global_variables_initializer().run()

        num_epochs = int(config.max_max_epoch)
        total_average_batch_time = 0.0
        epochs_info = []
        for i in range(num_epochs):
            # Learning-rate decay and per-epoch validation are disabled in
            # this variant; every epoch trains at config.learning_rate.
            m.assign_lr(session, config.learning_rate)
            print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
            train_perplexity, average_batch_time = run_epoch(
                session, m, train_data, m.train_op, verbose=True)
            total_average_batch_time += average_batch_time
            print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
            # Summarise every other epoch. Note this uses the 0-based index,
            # unlike the 1-based "Epoch:" lines above; kept as-is in case
            # downstream tooling parses it.
            if i % 2 == 0:
                epochs_info.append('%d:_:%.3f' % (i, train_perplexity))

        # Avoid ZeroDivisionError when no training epochs are configured.
        mean_batch_time = (total_average_batch_time / num_epochs
                           if num_epochs else 0.0)
        print("average_batch_time: %.6f" % mean_batch_time)
        print('epoch_info:' + ','.join(epochs_info))

        test_perplexity, _test_batch_time = run_epoch(
            session, mtest, test_data, tf.no_op())
        print("Test Perplexity: %.3f" % test_perplexity)
评论列表
文章目录