def decode():
  """Decode a file sentence-by-sentence."""
  with tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=NUM_THREADS)) as sess:
    # Create model and load parameters.
    with tf.variable_scope("model"):
      model, steps_done = create_model(sess, True, attention=FLAGS.attention, model_path=FLAGS.model_path)
    model.batch_size = 1  # We decode one sentence at a time.

    # Load vocabularies.
    sents_vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.sents" % FLAGS.input_vocab_size)
    parse_vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.parse" % FLAGS.output_vocab_size)
    sents_vocab, _ = data_utils.initialize_vocabulary(sents_vocab_path)
    _, rev_parse_vocab = data_utils.initialize_vocabulary(parse_vocab_path)

    start_time = time.time()
    # Decode each input sentence and write the predicted parse to the output file.
    with open(FLAGS.decode_input_path, 'r') as fin, open(FLAGS.decode_output_path, 'w') as fout:
      for line in fin:
        sentence = line.strip()
        token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), sents_vocab)
        # Pick the smallest bucket whose source length can hold this input.
        try:
          bucket_id = min([b for b in xrange(len(_buckets)) if _buckets[b][0] > len(token_ids)])
        except ValueError:
          print("Input sentence does not fit in any bucket. Skipping...")
          print("\t", line)
          continue
        # Build a single-element batch and run one forward pass through the model.
        encoder_inputs, decoder_inputs, target_weights = model.get_batch(
            {bucket_id: [(token_ids, [])]}, bucket_id)
        _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                         target_weights, bucket_id, True)
        # Greedy decoding: take the argmax at each step and truncate at EOS.
        outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
        if data_utils.EOS_ID in outputs:
          outputs = outputs[:outputs.index(data_utils.EOS_ID)]
        decoded_sentence = " ".join([tf.compat.as_str(rev_parse_vocab[output]) for output in outputs]) + '\n'
        fout.write(decoded_sentence)

    time_elapsed = time.time() - start_time
    print("Decoding time: ", time_elapsed)