def get_predicted_sentence(input_sentence, vocab, rev_vocab, model, sess):
    """Decode one input sentence with the seq2seq model and return the reply.

    Args:
        input_sentence: raw input string.
        vocab: dict mapping token string -> token id (used to encode the input).
        rev_vocab: sequence mapping token id -> token string (used to decode).
        model: seq2seq model exposing ``get_batch()`` and ``step()``.
        sess: TensorFlow session the model runs in.

    Returns:
        The predicted sentence as a single space-joined string of words.
    """
    input_token_ids = data_utils.sentence_to_token_ids(input_sentence, vocab)
    print(input_token_ids)  # NOTE(review): debug leftover — consider logging instead

    # Truncate so the input fits strictly inside the largest bucket; otherwise
    # the bucket search below would find no candidate and min() would raise.
    if len(input_token_ids) >= BUCKETS[-1][0]:
        input_token_ids = input_token_ids[:BUCKETS[-1][0] - 1]

    # Smallest bucket whose encoder size can hold the (truncated) input.
    # `range` instead of `xrange`: identical here, and Python-3 compatible.
    bucket_id = min(b for b in range(len(BUCKETS))
                    if BUCKETS[b][0] > len(input_token_ids))

    # Build a 1-element batch; decoder targets are empty at inference time.
    feed_data = {bucket_id: [(input_token_ids, [])]}
    encoder_inputs, decoder_inputs, target_weights = model.get_batch(feed_data, bucket_id)
    global_memory['inp'] = 1

    # One forward pass to obtain per-position output logits.
    _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                     target_weights, bucket_id,
                                     forward_only=True, beam_search=True)

    # Greedy decode: argmax token at each position, stopping at EOS.
    # NOTE(review): beam_search=True is passed to step() yet decoding here is
    # greedy over the returned logits — confirm that is intentional.
    outputs = []
    for logit in output_logits:
        selected_token_id = int(np.argmax(logit, axis=1))
        if selected_token_id == data_utils.EOS_ID:
            break
        outputs.append(selected_token_id)

    # Map token ids back to natural-language words.
    return ' '.join(rev_vocab[i] for i in outputs)
seq2seq_model_utils.py 文件源码
python
阅读 19
收藏 0
点赞 0
评论 0
评论列表
文章目录