main.py source file

python

Project: seq2seq_translator  Author: jtoy
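import os
import sys
import time

import numpy as np
import tensorflow as tf

import data_utils  # the project's own data_utils module (import path assumed)

# FLAGS, create_model and _buckets are defined or imported elsewhere in main.py.


def decode():
    # TensorFlow 1.x API (tf.compat.v1.Session under TensorFlow 2.x).
    with tf.Session() as sess:
        # Create the model and load its trained parameters.
        model = create_model(sess, True)
        model.batch_size = 1  # We decode one sentence at a time.

        # Load the source ("first") and target ("last") vocabularies.
        first_vocab_path = os.path.join(FLAGS.train_dir,
                                        "vocab%d.first" % FLAGS.first_vocab_size)
        last_vocab_path = os.path.join(FLAGS.train_dir,
                                       "vocab%d.last" % FLAGS.last_vocab_size)
        first_vocab, _ = data_utils.initialize_vocabulary(first_vocab_path)
        _, rev_last_vocab = data_utils.initialize_vocabulary(last_vocab_path)

        # Decode the sentence passed in FLAGS.input (it is not read from
        # standard input; the "> " is only a prompt marker).
        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = FLAGS.input
        # Get token-ids for the input sentence.
        token_ids = data_utils.sentence_to_token_ids(sentence, first_vocab)
        # Pick the smallest bucket whose encoder length exceeds the sentence
        # length; fail with a clear message if the sentence fits no bucket.
        fitting = [b for b in range(len(_buckets))
                   if _buckets[b][0] > len(token_ids)]
        if not fitting:
            raise ValueError("Input has %d tokens and fits no bucket in %r"
                             % (len(token_ids), _buckets))
        bucket_id = min(fitting)
        # Get a 1-element batch to feed the sentence to the model.
        encoder_inputs, decoder_inputs, target_weights = model.get_batch(
            {bucket_id: [(token_ids, [])]}, bucket_id)
        # Get output logits for the sentence (the trailing True requests a
        # forward-only pass).
        _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                         target_weights, bucket_id, True)
        # This is a greedy decoder: each output is the argmax of its logits.
        # Each logit has shape (1, vocab_size), so take element [0].
        outputs = [int(np.argmax(logit, axis=1)[0]) for logit in output_logits]
        # If there is an EOS symbol in outputs, cut them at that point.
        if data_utils.EOS_ID in outputs:
            outputs = outputs[:outputs.index(data_utils.EOS_ID)]
        # Map the output token-ids back to words in the target ("last") vocabulary.
        result = " ".join(rev_last_vocab[token_id] for token_id in outputs)
        print(result)
        # Save the translation to a timestamp-named file in FLAGS.output_dir.
        output_path = os.path.join(FLAGS.output_dir,
                                   str(int(time.time())) + ".txt")
        with open(output_path, "w") as text_file:
            text_file.write(result)
        print(output_path)
        sys.stdout.flush()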
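The bucket lookup in decode() picks the smallest bucket whose encoder length is strictly greater than the number of input tokens. The standalone sketch below isolates that rule so it can be tried without a model. It assumes buckets are (encoder_size, decoder_size) pairs sorted ascending, as in the TensorFlow translate tutorial; the bucket sizes and the name `pick_bucket` are hypothetical, not taken from this project.

```python
from typing import List, Sequence, Tuple

# Hypothetical bucket sizes: (encoder_length, decoder_length) pairs, sorted ascending.
BUCKETS: List[Tuple[int, int]] = [(5, 10), (10, 15), (20, 25), (40, 50)]


def pick_bucket(token_ids: Sequence[int],
                buckets: Sequence[Tuple[int, int]] = BUCKETS) -> int:
    """Return the index of the smallest bucket whose encoder size exceeds len(token_ids).

    Uses the same strict '>' comparison as decode(), and raises a descriptive
    ValueError instead of min()'s generic 'empty sequence' error when nothing fits.
    """
    fitting = [b for b in range(len(buckets)) if buckets[b][0] > len(token_ids)]
    if not fitting:
        raise ValueError("sentence of %d tokens fits no bucket" % len(token_ids))
    return min(fitting)


print(pick_bucket([4, 8, 15]))      # 0: three tokens, and 5 > 3
print(pick_bucket(list(range(5))))  # 1: five tokens, and 5 > 5 is false
```

Because the comparison is strict, a sentence exactly as long as a bucket's encoder size is pushed into the next bucket up.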
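The greedy decoding step in decode() (argmax per step, cut at the first EOS, map ids back to words) can be exercised on toy data. This is a minimal sketch assuming per-step logits of shape (1, vocab_size) and the special-token ids of the TensorFlow translate tutorial's data_utils (EOS_ID = 2); the toy vocabulary, logits and the name `greedy_decode` are made up for illustration.

```python
import numpy as np

EOS_ID = 2  # Assumed id, matching the TensorFlow translate tutorial's data_utils.


def greedy_decode(output_logits, rev_vocab, eos_id=EOS_ID):
    """Turn per-step logits of shape (1, vocab_size) into a space-joined sentence."""
    # Greedy choice: the highest-scoring token id at each decoder step.
    outputs = [int(np.argmax(logit, axis=1)[0]) for logit in output_logits]
    # Drop the first EOS and everything after it.
    if eos_id in outputs:
        outputs = outputs[:outputs.index(eos_id)]
    return " ".join(rev_vocab[token_id] for token_id in outputs)


# Toy reverse vocabulary and logits for a 3-step decode (batch size 1).
rev_vocab = ["_PAD", "_GO", "_EOS", "_UNK", "hello", "world"]
logits = [
    np.array([[0.0, 0.1, 0.2, 0.0, 3.0, 1.0]]),  # argmax -> 4 ("hello")
    np.array([[0.0, 0.0, 0.5, 0.0, 0.2, 2.5]]),  # argmax -> 5 ("world")
    np.array([[0.0, 0.0, 4.0, 0.0, 1.0, 1.0]]),  # argmax -> 2 (EOS)
]
print(greedy_decode(logits, rev_vocab))  # hello world
```

Greedy decoding commits to the single best token at each step; it is cheap, but unlike beam search it cannot revisit an early choice that later turns out to be poor.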