import os
import time

import numpy as np
import tensorflow as tf

import data_utils  # vocabulary/tokenization helpers from the TensorFlow seq2seq tutorial


def decode():
    # cap GPU memory so the chatbot can share the card with other processes
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
    config = tf.ConfigProto(gpu_options=gpu_options)
    with tf.Session(config=config) as sess:
        model = create_model(sess, True)  # forward_only=True: decode, don't train
        model.batch_size = 1  # respond to one sentence at a time
        enc_vocab_path = os.path.join(working_directory, "vocab%d.enc" % enc_vocab_size)
        dec_vocab_path = os.path.join(working_directory, "vocab%d.dec" % dec_vocab_size)
        enc_vocab, _ = data_utils.initialize_vocabulary(enc_vocab_path)
        _, rev_dec_vocab = data_utils.initialize_vocabulary(dec_vocab_path)
        print('Start chatting...')

        @bot.message_handler(func=lambda message: True)  # handle every incoming message
        def reply_all(message):
            sentence = message.text.lower()
            token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), enc_vocab)
            # truncate over-long input so it fits into the largest bucket;
            # otherwise min() below would fail on an empty list
            if len(token_ids) >= _buckets[-1][0]:
                token_ids = token_ids[:_buckets[-1][0] - 1]
            # pick the smallest bucket the sentence fits into
            bucket_id = min([b for b in xrange(len(_buckets))
                             if _buckets[b][0] > len(token_ids)])
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                             target_weights, bucket_id, True)
            # greedy decoding: take the argmax over the vocabulary at each step
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            # cut the reply at the first end-of-sequence token
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]
            message_text = " ".join([tf.compat.as_str(rev_dec_vocab[output])
                                     for output in outputs])
            bot.reply_to(message, message_text)

        # keep polling Telegram; on any error, back off for 5 seconds and retry
        # (the loop re-enters polling, so no extra polling() call is needed here)
        while True:
            try:
                bot.polling(none_stop=True)
            except Exception as ex:
                print(str(ex))
                bot.stop_polling()
                time.sleep(5)
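The function above relies on several module-level names defined earlier in the script: the seq2seq create_model helper, the _buckets list, the vocabulary sizes, the working directory, and a bot object from pyTelegramBotAPI. Below is a minimal sketch of that setup; the token placeholder, paths, vocabulary sizes, and bucket shapes are illustrative assumptions, not values from the original.

import telebot

# hypothetical values -- substitute your own token, paths, and sizes
bot = telebot.TeleBot("YOUR_TELEGRAM_BOT_TOKEN")
working_directory = "working_dir/"
enc_vocab_size = 20000
dec_vocab_size = 20000
# (input length, output length) pairs, as in the TensorFlow seq2seq tutorial
_buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]

if __name__ == "__main__":
    decode()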