def test(self, sess, token_ids):
    """Decode a single sentence greedily and return the argmax ids.

    Args:
        sess: Session used to run ``self.prediction``.
        token_ids: Encoder token ids for one sentence. They are padded here
            with ``data_utils.padding``.

    Returns:
        The argmax index along axis 1 of the fetched ``self.prediction``
        value, as a NumPy array.
    """
    import numpy as np  # local import: the file-level import block is not visible here

    # We decode one sentence at a time. The decoder input is just GO and the
    # label just EOS, fed through nextRandomBatch as a batch of size 1.
    token_ids = data_utils.padding(token_ids)
    target_ids = data_utils.padding([data_utils.GO_ID])
    y_ids = data_utils.padding([data_utils.EOS_ID])
    encoder_inputs, decoder_inputs, _, _ = data_utils.nextRandomBatch(
        [(token_ids, target_ids, y_ids)], batch_size=1)
    prediction = sess.run(self.prediction, feed_dict={
        self.encoder_inputs: encoder_inputs,
        self.decoder_inputs: decoder_inputs,
    })
    # ``sess.run`` returns a NumPy value, so take the argmax in NumPy.
    # Building a ``tf.arg_max`` op here would add a new node to the graph on
    # every call. Evaluating it with ``.eval()`` would also depend on a
    # registered default session instead of the ``sess`` passed in.
    # NOTE(review): assumes prediction is 2-D, e.g. (steps, vocab) — TODO confirm.
    # NOTE(review): the result is not truncated at EOS.
    return np.argmax(prediction, axis=1)
# Comment list
# Table of contents