import os
import sys

import jieba
import tensorflow as tf

import data_utils

# FLAGS and create_model are assumed to be defined earlier in the script.
def main(_):
  print("Loading vocabulary")
  cn_vocab_path = os.path.join(FLAGS.data_dir, "source_vocab.txt")
  en_vocab_path = os.path.join(FLAGS.data_dir, "target_vocab.txt")
  # Source (Chinese) side: token -> id; target (English) side: id -> token.
  cn_vocab, _ = data_utils.initialize_vocabulary(cn_vocab_path)
  _, rev_en_vocab = data_utils.initialize_vocabulary(en_vocab_path)

  print("Building model...")
  config = tf.ConfigProto(allow_soft_placement=True)
  with tf.Session(config=config) as sess:
    model = create_model(sess, False)
    # Decode from standard input.
    sys.stdout.write("> ")
    sys.stdout.flush()
    sentence = sys.stdin.readline()
    while sentence:
      # Segment the Chinese input with jieba, then map each word to its id,
      # falling back to UNK_ID for out-of-vocabulary words.
      seg_list = jieba.lcut(sentence.strip())
      token_ids = [cn_vocab.get(w.encode("utf-8"), data_utils.UNK_ID)
                   for w in seg_list]
      outputs = model.test(sess, token_ids).tolist()
      # Truncate the decoder output at the first end-of-sentence symbol.
      if data_utils.EOS_ID in outputs:
        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
      output = " ".join(tf.compat.as_str(rev_en_vocab[t]) for t in outputs)
      print(output.capitalize())
      sys.stdout.write("> ")
      sys.stdout.flush()
      sentence = sys.stdin.readline()
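For reference, data_utils.initialize_vocabulary reads a vocabulary file with one token per line and returns both lookup directions, which is why the code above keeps only the token-to-id dict on the Chinese side and only the id-to-token list on the English side. A minimal sketch of that behavior, assuming the one-token-per-line file format (the actual helper may add error handling on top of this contract):

def initialize_vocabulary(vocabulary_path):
  # One token per line; a token's id is its line number.
  with open(vocabulary_path, "rb") as f:
    rev_vocab = [line.strip() for line in f]
  vocab = {token: idx for idx, token in enumerate(rev_vocab)}
  return vocab, rev_vocab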