seq2seq.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:seq2seq 作者: yanwii 项目源码 文件源码
def predict(self):
        """Run an interactive chat loop against the seq2seq model.

        Tries to restore weights from ``self.model_path + 'params.pkl'``. If
        that fails, the error is printed and the loop still starts with the
        current (possibly untrained) parameters.

        Each turn reads a line from stdin, segments it with ``jieba``, and maps
        the segments to encoder ids via ``./data/enc.vocab``. It then decodes
        with beam search or greedy top-1 selection (chosen by
        ``self.beam_search``) and prints the reply, mapping ids back to tokens
        via ``./data/dec.vocab``.

        Returns when stdin reaches end-of-input or Ctrl-C is pressed at the
        prompt.
        """
        # Id substituted for segments missing from the encoder vocabulary
        # (presumably the UNK token -- confirm against the vocab builder).
        unk_id = 3
        # Decoding of an output sequence stops at the first occurrence of this
        # id (presumably EOS -- confirm against the vocab builder).
        eos_id = 1

        # Loading is deliberately best-effort: a missing or incompatible
        # checkpoint is reported and the chat loop still starts.
        try:
            self.load_state_dict(torch.load(self.model_path + 'params.pkl'))
        except Exception as e:
            print(e)
            print("No model!")

        # Build the token <-> id lookup tables (line number == id). The files
        # are read as UTF-8 explicitly so the Chinese vocab does not depend on
        # the platform's locale encoding (assumes the vocab files are UTF-8 --
        # confirm against the script that writes them).
        with open("./data/enc.vocab", encoding="utf-8") as enc_vocab:
            str_to_vec = {
                word.strip(): index for index, word in enumerate(enc_vocab)
            }
        with open("./data/dec.vocab", encoding="utf-8") as dec_vocab:
            vec_to_str = {
                index: word.strip() for index, word in enumerate(dec_vocab)
            }

        def decode(ids):
            """Join the tokens for ``ids``, stopping at the first ``eos_id``.

            Ids missing from the decoder vocabulary are rendered as "Un".
            """
            words = []
            for i in ids:
                if i == eos_id:
                    break
                words.append(vec_to_str.get(i, "Un"))
            return "".join(words)

        while True:
            try:
                input_strs = input("me > ")
            except (EOFError, KeyboardInterrupt):
                # End of input / Ctrl-C: leave the chat loop cleanly instead
                # of propagating a traceback out of the method.
                print()
                return
            if not input_strs.strip():
                # Nothing to encode; don't feed an empty sequence to the model.
                continue

            # Segment the sentence into words and map them to encoder ids.
            segments = jieba.lcut(input_strs)
            input_vec = [str_to_vec.get(w, unk_id) for w in segments]
            input_vec = self.make_infer_fd(input_vec)

            if self.beam_search:
                # Element 0 of each sample is decoded as the output id
                # sequence; element 3 is printed next to it (presumably the
                # beam score -- confirm in beamSearchDecoder).
                for sample in self.beamSearchDecoder(input_vec):
                    print("ai > ", decode(sample[0]), sample[3])
            else:
                logits = self.infer(input_vec)
                # Greedy choice: top-1 index along the last dimension.
                _, v = torch.topk(logits, 1)
                # The layout of `logits` is not visible here; this
                # transpose/index picks the first sequence's ids -- TODO
                # confirm against infer()'s output shape.
                pre = v.cpu().data.numpy().T.tolist()[0][0]
                print("ai > ", decode(pre))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号