def predict(self, batch, randFlag, max_steps=51):
    """Generate ``batch`` token sequences from the decoder and print them.

    Every sequence starts from token id 1. At each step the decoder output is
    projected to vocabulary scores by ``self.h2w``. The next token is chosen in
    one of two ways:

    * sampled, via ``predictRandom`` on the softmax, when ``randFlag`` is truthy;
    * taken greedily (argmax of the last output row) otherwise.

    Exactly ``max_steps`` tokens are generated per sequence; there is no
    early stop.

    Args:
        batch: number of sequences to generate.
        randFlag: sample tokens when truthy, otherwise decode greedily.
        max_steps: number of decoding steps (tokens generated per sequence).
            Defaults to 51, the count the previously hard-coded code produced
            (one initial step plus a 50-iteration loop).

    Only sequences that contain "</s>" are printed, truncated before the first
    "</s>". Sequences that never emit it are silently skipped. Returns None.
    """
    # Every sequence starts from token id 1 (presumably the <s>/BOS id --
    # confirm against the vocabulary).
    t = self.makeEmbedBatch([[1] for _ in range(batch)])
    steps = []  # steps[k] holds the token ids chosen at step k, one per sequence
    for _ in range(max_steps):
        # NOTE(review): no recurrent state is passed between calls to self.dec.
        # Unless dec keeps its own state, each step only sees the latest token;
        # confirm this is intended.
        ys_w = [self.h2w(y) for y in self.dec(t, train=False)]
        if randFlag:
            tokens = [predictRandom(F.softmax(y_each)) for y_each in ys_w]
        else:
            tokens = [y_each.data[-1].argmax(0) for y_each in ys_w]
        steps.append(tokens)
        # Feed the chosen tokens back in as the next step's input.
        t = [self.embed(xp.array([tok], dtype=xp.int32)) for tok in tokens]
    # Transpose (step, sequence) -> (sequence, step) so each row is one output.
    for ids in xp.array(steps).T:
        words = [self.vocab.itos(nint) for nint in ids]
        if "</s>" in words:
            print(" Gen:{}".format("".join(words[:words.index("</s>")])))
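# [web-page residue, not code] 评论列表 ("comment list")
# [web-page residue, not code] 文章目录 ("table of contents")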