main.py 文件源码

python
阅读 41 收藏 0 点赞 0 评论 0

项目:variational-text-tensorflow 作者: carpedm20 项目源码 文件源码
def main(_):
  pp.pprint(flags.FLAGS.__flags)

  data_path = "./data/%s" % FLAGS.dataset
  reader = TextReader(data_path)

  with tf.Session() as sess:
    m = MODELS[FLAGS.model]
    model = m(sess, reader, dataset=FLAGS.dataset,
              embed_dim=FLAGS.embed_dim, h_dim=FLAGS.h_dim,
              learning_rate=FLAGS.learning_rate, max_iter=FLAGS.max_iter,
              checkpoint_dir=FLAGS.checkpoint_dir)

    if FLAGS.forward_only:
      model.load(FLAGS.checkpoint_dir)
    else:
      model.train(FLAGS)

    while True:
      text = raw_input(" [*] Enter text to test: ")
      model.sample(5, text)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号