def test_decoder(config):
    """Run an interactive read-decode-print loop for the configured model.

    Calls ``data_utils.create_vocabulary`` / ``initialize_vocabulary`` on the
    ``chitchat.train`` answer/query files under ``config.train_dir``, creates
    the model selected by ``config.name_model`` in forward-only mode, then
    reads one sentence per line from stdin and prints the decoded reply.
    The loop ends when ``readline`` returns an empty string (EOF).

    Args:
        config: Configuration object. This function reads ``train_dir``,
            ``vocab_size``, ``name_model`` and ``buckets`` from it and
            forwards it to the model factory.

    Raises:
        ValueError: If ``config.name_model`` matches none of the known
            model configurations (previously this surfaced as a
            ``NameError`` on the unbound ``model`` variable).
    """
    train_path = os.path.join(config.train_dir, "chitchat.train")
    data_path_list = [train_path + ".answer", train_path + ".query"]
    vocab_path = os.path.join(config.train_dir, "vocab%d.all" % config.vocab_size)
    data_utils.create_vocabulary(vocab_path, data_path_list, config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

    # Resolve the model family once; it does not change between sentences.
    st_names = [gst_config.name_model, gcc_config.name_model, gbk_config.name_model]
    rl_names = [grl_config.name_model, pre_grl_config.name_model]
    is_st_model = config.name_model in st_names
    if not is_st_model and config.name_model not in rl_names:
        raise ValueError("Unknown model name: %r" % (config.name_model,))

    with tf.Session() as sess:
        if is_st_model:
            model = create_st_model(sess, config, forward_only=True, name_scope=config.name_model)
        else:
            model = create_rl_model(sess, config, forward_only=True, name_scope=config.name_model)
        # Sentences are decoded one at a time.
        model.batch_size = 1

        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), vocab)
            print("token_id: ", token_ids)

            # Use the first bucket whose input size fits the sentence; if
            # none fits, fall back to the last bucket and warn the user.
            bucket_id = len(config.buckets) - 1
            for i, bucket in enumerate(config.buckets):
                if bucket[0] >= len(token_ids):
                    bucket_id = i
                    break
            else:
                # Interpolate explicitly: print() does not apply %-formatting
                # to extra positional arguments the way logging calls do.
                print("Sentence truncated: %s" % sentence.rstrip("\n"))

            encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
                {bucket_id: [(token_ids, [1])]}, bucket_id)

            if is_st_model:
                # st_model step: take the argmax of each step's logits.
                output_logits, _ = model.step(sess, encoder_inputs, decoder_inputs,
                                              target_weights, bucket_id, True)
                # int() needs a single-element argmax result — presumably
                # guaranteed by batch_size == 1; confirm against model.step.
                outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
                # Drop everything from the first end-of-sentence id onwards.
                if data_utils.EOS_ID in outputs:
                    outputs = outputs[:outputs.index(data_utils.EOS_ID)]
                print(" ".join([str(rev_vocab[output]) for output in outputs]))
            else:
                # beam_search step: the third return value is iterated as a
                # collection of candidate token sequences.
                _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                                 target_weights, reward=1,
                                                 bucket_id=bucket_id, forward_only=True)
                for index, output in enumerate(output_logits):
                    print("index: %d, answer tokens: %s" % (index, str(output)))
                    # NOTE(review): .index() assumes each candidate is a
                    # list — confirm against the RL model's step().
                    if data_utils.EOS_ID in output:
                        output = output[:output.index(data_utils.EOS_ID)]
                    print(" ".join([str(rev_vocab[out]) for out in output]))

            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
grl_train.py 文件源码
python
阅读 15
收藏 0
点赞 0
评论 0
评论列表
文章目录