def main(_):
    """Build the selected model, train or restore it, then sample interactively.

    Pretty-prints the parsed flags, creates a ``TextReader`` over
    ``./data/<FLAGS.dataset>`` and instantiates ``MODELS[FLAGS.model]``
    inside a ``tf.Session``. When ``FLAGS.forward_only`` is set the model is
    loaded from ``FLAGS.checkpoint_dir``; otherwise ``model.train(FLAGS)`` is
    called. Afterwards the user is prompted for text in a loop and every
    entry is passed to ``model.sample(5, text)``.

    The prompt loop ends when input is exhausted (EOF) or the user presses
    Ctrl-C at the prompt.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by the application runner -- confirm against the caller).
    """
    pp.pprint(flags.FLAGS.__flags)

    data_path = "./data/%s" % FLAGS.dataset
    reader = TextReader(data_path)

    # ``raw_input`` only exists on Python 2; on Python 3 the equivalent
    # builtin is ``input``. Resolve the right one once, up front.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    with tf.Session() as sess:
        model_cls = MODELS[FLAGS.model]
        model = model_cls(
            sess,
            reader,
            dataset=FLAGS.dataset,
            embed_dim=FLAGS.embed_dim,
            h_dim=FLAGS.h_dim,
            learning_rate=FLAGS.learning_rate,
            max_iter=FLAGS.max_iter,
            checkpoint_dir=FLAGS.checkpoint_dir,
        )

        if FLAGS.forward_only:
            model.load(FLAGS.checkpoint_dir)
        else:
            model.train(FLAGS)

        # The loop stays inside the session block because the model was
        # constructed with ``sess``.
        while True:
            try:
                text = read_line(" [*] Enter text to test: ")
            except (EOFError, KeyboardInterrupt):
                # No more input, or the user asked to stop: leave the loop
                # quietly instead of surfacing a traceback.
                break
            model.sample(5, text)
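# 评论列表 (comment list) -- web-page navigation text, not part of the program
# 文章目录 (article table of contents) -- web-page navigation text, not part of the program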