seq2seq_model_utils.py source code

python

Project: Biseq2Seq_NLG  Author: MaZhiyuanBUAA
import numpy as np

# data_utils, BUCKETS and global_memory are defined elsewhere in the Biseq2Seq_NLG project.


def get_predicted_sentence(input_sentence, vocab, rev_vocab, model, sess):
    input_token_ids = data_utils.sentence_to_token_ids(input_sentence, vocab)
    print(input_token_ids)
    # Which bucket does it belong to? Truncate inputs that are too long for the
    # largest bucket, then pick the smallest bucket whose encoder size exceeds the length.
    if len(input_token_ids) >= BUCKETS[-1][0]:
        input_token_ids = input_token_ids[:BUCKETS[-1][0] - 1]
    bucket_id = min([b for b in range(len(BUCKETS)) if BUCKETS[b][0] > len(input_token_ids)])
    outputs = []

    feed_data = {bucket_id: [(input_token_ids, outputs)]}
    # Get a 1-element batch to feed the sentence to the model.
    encoder_inputs, decoder_inputs, target_weights = model.get_batch(feed_data, bucket_id)
    global_memory['inp'] = 1
    # Get output logits for the sentence.
    _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights,
                                     bucket_id, forward_only=True, beam_search=True)
    outputs = []
    # This is a greedy decoder: outputs are just the argmaxes of output_logits.
    for logit in output_logits:
        # logit has shape (1, vocab_size); take the argmax of its single row.
        selected_token_id = int(np.argmax(logit, axis=1)[0])
        if selected_token_id == data_utils.EOS_ID:
            break
        else:
            outputs.append(selected_token_id)
    # Form the output sentence in natural language.
    outputs = ' '.join([rev_vocab[i] for i in outputs])

    return outputs
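The two pure steps in get_predicted_sentence, bucket selection with truncation and greedy argmax decoding, can be tried in isolation. The following is a minimal sketch, not the project's code: the model call is replaced by precomputed logits, and the BUCKETS values, EOS_ID = 2, and the helper names choose_bucket and greedy_decode are hypothetical, chosen in the style of the TensorFlow seq2seq tutorial.

```python
import numpy as np

# Hypothetical (encoder_size, decoder_size) pairs; the project's real BUCKETS may differ.
BUCKETS = [(5, 10), (10, 15), (20, 25), (40, 50)]
EOS_ID = 2  # assumed id of the end-of-sentence token (data_utils.EOS_ID in the original)


def choose_bucket(token_ids, buckets=BUCKETS):
    """Hypothetical helper: truncate token_ids to fit the largest bucket, then
    return (bucket_id, token_ids) for the smallest bucket whose encoder size
    is strictly greater than the input length."""
    max_len = buckets[-1][0]
    if len(token_ids) >= max_len:
        token_ids = token_ids[:max_len - 1]
    bucket_id = min(b for b in range(len(buckets)) if buckets[b][0] > len(token_ids))
    return bucket_id, token_ids


def greedy_decode(output_logits, rev_vocab, eos_id=EOS_ID):
    """Hypothetical helper: take the argmax token at each decoder step
    (batch size 1), stop at EOS, and join the words with spaces."""
    token_ids = []
    for logit in output_logits:
        # Each logit has shape (1, vocab_size); index [0] extracts a scalar id.
        token_id = int(np.argmax(logit, axis=1)[0])
        if token_id == eos_id:
            break
        token_ids.append(token_id)
    return " ".join(rev_vocab[i] for i in token_ids)


if __name__ == "__main__":
    rev_vocab = ["_PAD", "_GO", "_EOS", "_UNK", "hello", "world"]
    # One-hot rows of shape (1, 6) stand in for the model's per-step logits.
    logits = [np.eye(6)[[4]], np.eye(6)[[5]], np.eye(6)[[2]], np.eye(6)[[4]]]
    print(greedy_decode(logits, rev_vocab))  # hello world
    print(choose_bucket([7, 8, 9]))          # (0, [7, 8, 9])
```

Because the bucket test uses a strict ">", an input of exactly 5 tokens skips the (5, 10) bucket and lands in bucket 1. Truncating to one token fewer than the largest encoder size guarantees that some bucket always qualifies.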