tf_flask_api.py file source code


Project: Seq2Seq-Tensorflow-1.0-Chatbot, author: igorvishnevskiy (project source code)
# Assumed module-level context (set up elsewhere in the project's tf_flask_api.py):
#   from flask import request, jsonify
#   import json, logging
#   import numpy as np
#   import tensorflow as tf
#   sess, model, en_vocab, rev_fr_vocab, _buckets and data_utils are initialized at startup.
def post(self):
        data_received = request.json
        if not data_received:
            # Fall back to a JSON string posted in the "payload" form field.
            # json.loads replaces the original eval(), which would execute
            # arbitrary code from the request body (assumes the payload is JSON).
            data_received = json.loads(request.form["payload"])

        sentence = data_received["text"]
        print(sentence)

        token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), en_vocab)
        # Which bucket does it belong to?
        bucket_id = len(_buckets) - 1
        for i, bucket in enumerate(_buckets):
            if bucket[0] >= len(token_ids):
                bucket_id = i
                break
        else:
            logging.warning("Sentence truncated: %s", sentence)

        # Get a 1-element batch to feed the sentence to the model.
        encoder_inputs, decoder_inputs, target_weights = model.get_batch(
          {bucket_id: [(token_ids, [])]}, bucket_id)
        # Get output logits for the sentence.
        _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                       target_weights, bucket_id, True)
        # This is a greedy decoder - outputs are just argmaxes of output_logits.
        outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
        # If there is an EOS symbol in outputs, cut them at that point.
        if data_utils.EOS_ID in outputs:
            outputs = outputs[:outputs.index(data_utils.EOS_ID)]
        # Print out French sentence corresponding to outputs.
        response = " ".join(tf.compat.as_str(rev_fr_vocab[output]) for output in outputs)
        print(response)

        return jsonify({"text": response})
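The handler accepts either a parsed JSON body or a JSON string in a `payload` form field. That fallback can be sketched in isolation; `parse_payload` is a hypothetical helper, not part of the project:

```python
import json

def parse_payload(json_body, form):
    # Prefer the parsed JSON body; otherwise decode the "payload" form field.
    # json.loads is used instead of eval so request data is never executed.
    data = json_body
    if not data:
        data = json.loads(form["payload"])
    return data

print(parse_payload(None, {"payload": '{"text": "hello"}'}))  # {'text': 'hello'}
```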
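The bucket-selection loop picks the smallest bucket whose encoder length fits the tokenized sentence, falling back to the last bucket when none does. A standalone sketch, with hypothetical bucket sizes (the real `_buckets` values come from the project's training config):

```python
# Hypothetical (encoder_len, decoder_len) buckets; the project defines its own.
_buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]

def select_bucket(token_ids):
    # Default to the largest bucket; the sentence will be truncated to fit it.
    bucket_id = len(_buckets) - 1
    for i, bucket in enumerate(_buckets):
        if bucket[0] >= len(token_ids):
            return i  # smallest bucket that can hold the sentence
    return bucket_id

print(select_bucket([1, 2, 3]))        # fits the first bucket
print(select_bucket(list(range(60))))  # too long for all: last bucket
```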
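The greedy-decoding step takes the argmax of each per-position logit array and cuts the sequence at the first EOS token. A minimal sketch with fabricated logits, assuming the TensorFlow tutorial's `data_utils.EOS_ID == 2`:

```python
import numpy as np

EOS_ID = 2  # assumed value of data_utils.EOS_ID in the tutorial vocabulary

# Fabricated logits: one (1, vocab_size) array per output position, as
# returned by model.step for a batch of size 1.
output_logits = [
    np.array([[0.1, 0.9, 0.0]]),  # argmax -> token 1
    np.array([[0.0, 0.1, 0.9]]),  # argmax -> token 2 (EOS)
    np.array([[0.9, 0.0, 0.1]]),  # argmax -> token 0 (ignored after EOS)
]

# Greedy decoder: each output is the argmax of that position's logits.
outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
# Truncate at the first EOS symbol, as the handler does.
if EOS_ID in outputs:
    outputs = outputs[:outputs.index(EOS_ID)]
print(outputs)  # [1]
```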