def _find_bucket_id(buckets, length):
    """Return the index of the first bucket whose size exceeds ``length``.

    Args:
      buckets: sequence of buckets; element ``[0]`` of each bucket is compared
        against ``length`` (presumably the source/encoder size -- confirm
        against the ``_buckets`` definition).
      length: number of source token ids that must fit.

    Returns:
      The smallest index ``b`` with ``buckets[b][0] > length``, or ``None``
      when no bucket is large enough.
    """
    for bucket_id, bucket in enumerate(buckets):
        if bucket[0] > length:
            return bucket_id
    return None


def decode():
    """Interactively translate sentences read from standard input.

    Opens a TensorFlow session, builds the model with
    ``create_model(sess, True)`` and forces ``batch_size = 1``. It then loads
    the source vocabulary and the ``rev_dst_lang_vocab`` mapping, which turns
    output ids back into tokens. Both files live under ``PATH_TO_DATA_FILES``.

    The loop repeatedly writes a ``"> "`` prompt, reads one line from stdin
    and prints the greedy (argmax) decoding of that line. It stops at end of
    input. A sentence too long for every bucket is reported on stderr and
    skipped, so it does not abort the session.
    """
    with tf.Session() as sess:
        # Create model and load parameters.
        model = create_model(sess, True)
        model.batch_size = 1  # We decode one sentence at a time.

        # Load vocabularies. "%" binds tighter than "+", so only the file-name
        # suffix is formatted before being appended to the directory prefix.
        src_lang_vocab_path = (PATH_TO_DATA_FILES + FLAGS.src_lang
                               + "_mapping%d.txt" % FLAGS.src_lang_vocab_size)
        dst_lang_vocab_path = (PATH_TO_DATA_FILES + FLAGS.dst_lang
                               + "_mapping%d.txt" % FLAGS.dst_lang_vocab_size)
        src_lang_vocab, _ = data_utils.initialize_vocabulary(src_lang_vocab_path)
        _, rev_dst_lang_vocab = data_utils.initialize_vocabulary(dst_lang_vocab_path)

        # Decode from standard input, one line per prompt.
        while True:
            sys.stdout.write("> ")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
            if not sentence:  # readline() returns "" only at end of input.
                break

            # Get token-ids for the input sentence.
            token_ids = data_utils.sentence_to_token_ids(
                tf.compat.as_bytes(sentence), src_lang_vocab)

            # Which bucket does it belong to? Taking min() over an empty list
            # would raise ValueError for over-long input, so skip such lines.
            bucket_id = _find_bucket_id(_buckets, len(token_ids))
            if bucket_id is None:
                sys.stderr.write(
                    "Sentence too long (%d tokens) for any bucket; skipped.\n"
                    % len(token_ids))
                continue

            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
            # Get output logits for the sentence. The trailing True is
            # presumably model.step's forward_only flag -- confirm in the model.
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                             target_weights, bucket_id, True)

            # Greedy decoder: each output is the argmax of its logits.
            # With batch_size 1, argmax(axis=1) is a 1-element array. .item()
            # extracts the scalar without the int()-on-array conversion that
            # NumPy deprecated in 1.25.
            outputs = [np.argmax(logit, axis=1).item() for logit in output_logits]

            # If there is an EOS symbol in outputs, cut them at that point.
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]

            # Print the target-language sentence corresponding to outputs.
            print(" ".join(tf.compat.as_str(rev_dst_lang_vocab[output])
                           for output in outputs))
# [page residue, not code] 评论列表 ("Comment list")
# [page residue, not code] 文章目录 ("Table of contents")