def predict(self):
    """Run an interactive console chat loop with the model.

    Loads weights from ``self.model_path + 'params.pkl'``; if that fails the
    error is printed and the loop continues with the current (possibly
    untrained) weights.  Reads the encoder/decoder vocabularies from
    ``./data/enc.vocab`` and ``./data/dec.vocab``, then repeatedly reads a
    line from stdin, segments it with jieba, maps the words to encoder ids,
    decodes a reply (beam search if ``self.beam_search`` is truthy, otherwise
    greedy top-1 via ``self.infer``) and prints it.  The loop ends on EOF
    (Ctrl-D / Ctrl-Z) or Ctrl-C at the prompt.
    """
    # Token ids hard-coded by this code: unknown input words map to 3 and
    # decoding stops at id 1.  Presumably these are the <unk> and <eos>
    # entries of the vocab files -- TODO confirm against the vocab builder.
    unk_id = 3
    eos_id = 1

    try:
        # map_location loads every tensor onto the CPU first, so a checkpoint
        # saved on a GPU still loads on a CPU-only machine; load_state_dict
        # then copies the values into the existing parameters.
        state = torch.load(self.model_path + 'params.pkl',
                           map_location=lambda storage, loc: storage)
        self.load_state_dict(state)
    except Exception as e:
        # Deliberate best-effort: report and keep going with current weights.
        print(e)
        print("No model!")
    # Inference mode: disables dropout and uses running batch-norm stats.
    self.eval()

    # Encoder vocab: word -> id, where the id is the line number in the file.
    with open("./data/enc.vocab", encoding="utf-8") as enc_vocab:
        str_to_vec = {word.strip(): index for index, word in enumerate(enc_vocab)}
    # Decoder vocab: id (line number) -> word.
    with open("./data/dec.vocab", encoding="utf-8") as dec_vocab:
        vec_to_str = {index: word.strip() for index, word in enumerate(dec_vocab)}

    def decode(ids):
        """Join the words for ``ids`` up to (not including) the first eos_id."""
        words = []
        for i in ids:
            if i == eos_id:
                break
            # "Un" is printed for ids missing from the decoder vocab.
            words.append(vec_to_str.get(i, "Un"))
        return "".join(words)

    while True:
        try:
            input_strs = input("me > ")
        except (EOFError, KeyboardInterrupt):
            # Leave the chat cleanly instead of dumping a traceback.
            print()
            break
        if not input_strs.strip():
            # Nothing to encode; just prompt again.
            continue
        # Segment the sentence into words and map them to encoder ids.
        segment = jieba.lcut(input_strs)
        input_vec = [str_to_vec.get(w, unk_id) for w in segment]
        input_vec = self.make_infer_fd(input_vec)
        # inference
        if self.beam_search:
            samples = self.beamSearchDecoder(input_vec)
            for sample in samples:
                # sample[0]: decoded id sequence; sample[3] is printed
                # alongside it (presumably a score -- confirm in
                # beamSearchDecoder).
                print("ai > ", decode(sample[0]), sample[3])
        else:
            logits = self.infer(input_vec)
            _, v = torch.topk(logits, 1)
            # Greedy top-1 ids; the transpose and [0][0] pick the id
            # sequence of the first (only) batch item.
            pre = v.cpu().data.numpy().T.tolist()[0][0]
            print("ai > ", decode(pre))
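# 评论列表 / 文章目录 ("comment list" / "table of contents"): navigation text
# left over from the web page this code was copied from; not part of the program.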