trigger.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:Seq2Seq-Chatbot 作者: FR0ST1N 项目源码 文件源码
def decode():
  """Chat with the trained model interactively on stdin/stdout.

  Opens a TensorFlow session (capped at 20% of GPU memory), loads the model
  via ``create_model(sess, True)`` and the encoder/decoder vocabularies from
  ``working_directory``, then repeatedly prompts with ``"> "`` and reads one
  line from stdin until EOF (an empty ``readline()`` result).

  For each input line:
    * If the line (without its trailing newline) is not in the module-level
      ``lines`` collection, print ``'i dont understand you'``.
    * Otherwise greedily decode a reply, cut it at the first EOS token and
      print it. When ``trigger_activator`` returns ``True`` for the reply, the
      last decoded token is omitted from what is printed.

  Relies on module-level names from elsewhere in the file: ``tf``, ``np``,
  ``os``, ``sys``, ``data_utils``, ``create_model``, ``working_directory``,
  ``enc_vocab_size``, ``dec_vocab_size``, ``_buckets``, ``lines`` and
  ``trigger_activator``.
  """
  gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
  config = tf.ConfigProto(gpu_options=gpu_options)

  with tf.Session(config=config) as sess:
    model = create_model(sess, True)
    model.batch_size = 1  # decode one sentence at a time
    enc_vocab_path = os.path.join(working_directory, "vocab%d.enc" % enc_vocab_size)
    dec_vocab_path = os.path.join(working_directory, "vocab%d.dec" % dec_vocab_size)

    enc_vocab, _ = data_utils.initialize_vocabulary(enc_vocab_path)
    _, rev_dec_vocab = data_utils.initialize_vocabulary(dec_vocab_path)

    sys.stdout.write("> ")
    sys.stdout.flush()
    sentence = sys.stdin.readline()
    while sentence:
      # Only sentences found in ``lines`` get a model reply. Check this first
      # so unknown input does not pay for (or crash in) a decoding step.
      # Use rstrip("\n") rather than [:-1]: a final line without a newline
      # would otherwise lose its last real character.
      if sentence.rstrip("\n") in lines:
        token_ids = data_utils.sentence_to_token_ids(
            tf.compat.as_bytes(sentence), enc_vocab)

        # Lowest-index bucket whose encoder length exceeds the input length.
        bucket_id = next(
            (b for b, bucket in enumerate(_buckets) if bucket[0] > len(token_ids)),
            None)
        if bucket_id is None:
          # The input is longer than every bucket. min() over an empty list
          # used to raise ValueError here. Instead, truncate the input so it
          # fits the bucket with the largest encoder length.
          bucket_id = max(range(len(_buckets)), key=lambda b: _buckets[b][0])
          token_ids = token_ids[:_buckets[bucket_id][0] - 1]

        encoder_inputs, decoder_inputs, target_weights = model.get_batch(
            {bucket_id: [(token_ids, [])]}, bucket_id)
        _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                         target_weights, bucket_id, True)
        # batch_size is 1, so argmax over axis 1 yields one id per step. Take
        # element [0] explicitly: int() on a 1-element array is deprecated in
        # NumPy >= 1.25.
        outputs = [int(np.argmax(logit, axis=1)[0]) for logit in output_logits]
        if data_utils.EOS_ID in outputs:
          outputs = outputs[:outputs.index(data_utils.EOS_ID)]

        words = [tf.compat.as_str(rev_dec_vocab[output]) for output in outputs]
        reply = " ".join(words)
        # The explicit comparison is kept on purpose: only a result equal to
        # True drops the last token (presumably the trigger itself; TODO
        # confirm against trigger_activator).
        if trigger_activator(reply) == True:  # noqa: E712
          print(" ".join(words[:-1]))
        else:
          print(reply)
      else:
        print('i dont understand you')
      sys.stdout.write("> ")
      sys.stdout.flush()
      sentence = sys.stdin.readline()

# Check whether the decoded sentence contains a trigger.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号