test.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:basic-encoder-decoder 作者: pemywei 项目源码 文件源码
def main(_):
    """Run an interactive translation loop over standard input.

    Loads the source and target vocabularies from ``FLAGS.data_dir``, builds
    the model inside a TensorFlow session, and then repeatedly reads one
    sentence per line from stdin. Each sentence is segmented with jieba,
    mapped to token ids (unknown words become ``data_utils.UNK_ID``), passed
    to ``model.test``, and the resulting ids are mapped back to target words
    and printed. The loop ends when stdin reaches end-of-file.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by ``tf.app.run`` -- confirm against the caller).
    """
    print("Loading vocabulary")
    cn_vocab_path = os.path.join(FLAGS.data_dir, "source_vocab.txt")
    en_vocab_path = os.path.join(FLAGS.data_dir, "target_vocab.txt")

    # The source side only needs the first return value (looked up by word
    # below); the target side only needs the second (indexed by token id).
    cn_vocab, _ = data_utils.initialize_vocabulary(cn_vocab_path)
    _, rev_en_vocab = data_utils.initialize_vocabulary(en_vocab_path)

    print("Building model...")
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        # NOTE(review): the meaning of the ``False`` flag is defined by
        # create_model, which is not visible here -- confirm.
        model = create_model(sess, False)

        # Decode from standard input.
        while True:
            # Emit the prompt from a single place and without a trailing
            # newline, so every prompt (not just the first) keeps the
            # cursor on the same line as "> ".
            sys.stdout.write("> ")
            sys.stdout.flush()

            sentence = sys.stdin.readline()
            if not sentence:
                # readline() returns an empty string only at end-of-file.
                break

            seg_list = jieba.lcut(sentence.strip())
            # Vocabulary keys are looked up as UTF-8 encoded bytes.
            token_ids = [
                cn_vocab.get(w.encode(encoding="utf-8"), data_utils.UNK_ID)
                for w in seg_list
            ]

            outputs = model.test(sess, token_ids).tolist()
            # Drop the end-of-sequence marker and everything after it.
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]

            output = " ".join(
                [tf.compat.as_str(rev_en_vocab[token_id]) for token_id in outputs]
            )
            print(output.capitalize())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号