seq2seq.py source code

python

Project: basic-encoder-decoder  Author: pemywei
    # Method of the model class in seq2seq.py. Assumes `import numpy as np` and the
    # project's own `data_utils` module are imported at the top of the file.
    def test(self, sess, token_ids):
        # Decode one sentence at a time (batch_size=1).
        token_ids = data_utils.padding(token_ids)
        target_ids = data_utils.padding([data_utils.GO_ID])
        y_ids = data_utils.padding([data_utils.EOS_ID])
        encoder_inputs, decoder_inputs, _, _ = data_utils.nextRandomBatch(
            [(token_ids, target_ids, y_ids)], batch_size=1)
        prediction = sess.run(self.prediction, feed_dict={
            self.encoder_inputs: encoder_inputs,
            self.decoder_inputs: decoder_inputs
        })
        # Disabled in the original source (greedy decoding with EOS truncation):
        # prediction = tf.split(0, self.num_steps, prediction)
        # outputs = [int(np.argmax(predict)) for predict in prediction]
        # if data_utils.EOS_ID in outputs:
        #     outputs = outputs[:outputs.index(data_utils.EOS_ID)]
        #
        # `prediction` is already a NumPy array returned by sess.run, so take the
        # argmax with NumPy. The original built a new tf.arg_max op on every call
        # and evaluated it with .eval(), which needs a default session.
        return np.argmax(prediction, axis=1)
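
The commented-out lines in test describe a greedy decoder: take the argmax of the scores at every step, then cut the result at the first EOS symbol. A minimal sketch of that idea is shown here with plain Python lists so it runs without TensorFlow. The function name greedy_decode, the EOS_ID value of 2, and the example scores are all hypothetical and are not taken from the project.

```python
EOS_ID = 2  # hypothetical value; the project defines its own data_utils.EOS_ID


def greedy_decode(step_scores, eos_id=EOS_ID):
    """Pick the highest-scoring token id at each step and cut at the first EOS.

    step_scores is a list with one entry per decoding step; each entry is a
    list of scores, one per vocabulary id.
    """
    outputs = []
    for scores in step_scores:
        # argmax over the vocabulary for this step
        best_id = max(range(len(scores)), key=scores.__getitem__)
        outputs.append(best_id)
    # If there is an EOS symbol in the outputs, drop it and everything after it.
    if eos_id in outputs:
        outputs = outputs[:outputs.index(eos_id)]
    return outputs


example_scores = [
    [0.1, 0.2, 0.0, 0.9],  # step 0: best id is 3
    [0.8, 0.1, 0.0, 0.1],  # step 1: best id is 0
    [0.0, 0.1, 0.7, 0.2],  # step 2: best id is 2 (EOS)
    [0.3, 0.6, 0.0, 0.1],  # step 3: best id is 1, dropped because it follows EOS
]
print(greedy_decode(example_scores))
```

Greedy decoding commits to the single best token at each step; it is simple and fast, but it does not revisit earlier choices the way a beam search would.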