def predict(self, xs):
    """Predict category and dependency scores for a batch of sentences.

    Args:
        xs: list of split sentences. Each sentence is passed to
            ``self.extractor.process``; the resulting features are batched
            by ``concat_examples`` into ``(ws, ss, ps)`` for ``self.forward``.

    Returns:
        A pair ``(cat_ys, dep_ys)`` of lists with one entry per sentence.
        ``cat_ys[i]`` holds the category scores for rows ``1 .. len(xs[i])``
        (row 0 and any padding rows are dropped). ``dep_ys[i]`` holds the
        log-softmax (over axis 1) of the dependency scores for the same rows,
        with the last column excluded. Both are the underlying ``.data``
        arrays, not Variables. An empty ``xs`` yields ``([], [])``.
    """
    if not xs:
        # concat_examples raises on an empty batch; there is nothing to score.
        return [], []
    batchsize = len(xs)
    # Keep the original sentences in ``xs``: their lengths are needed below to
    # strip padding. Each extracted feature is unpacked into (ws, ss, ps) by
    # concat_examples, so its len() would be 3, not the sentence length.
    features = [self.extractor.process(x) for x in xs]
    ws, ss, ps = concat_examples(features, padding=IGNORE)
    cat_ys, dep_ys = self.forward(ws, ss, ps)
    # Stack the lists returned by forward along a new axis 2, then swap
    # axes 1 and 2 so that axis 0 stays the batch axis.
    cat_ys = F.transpose(F.stack(cat_ys, 2), (0, 2, 1))
    dep_ys = F.transpose(F.stack(dep_ys, 2), (0, 2, 1))

    # Split into one (1, ...) slice per sentence, drop the singleton batch
    # axis, and keep rows 1 .. len(x); rows past the sentence are padding.
    # NOTE(review): presumably row 0 is a start-of-sentence slot — confirm.
    cat_ys = [F.squeeze(y, 0).data[1:len(x) + 1]
              for x, y in zip(xs, F.split_axis(cat_ys, batchsize, 0))]
    # Squeeze the batch axis *before* slicing so the row slice applies to the
    # token axis rather than the size-1 batch axis.
    # NOTE(review): ``:-1`` drops only the final column, so padded columns (if
    # any) still receive probability mass — confirm this is intended.
    dep_ys = [F.log_softmax(F.squeeze(y, 0)[1:len(x) + 1, :-1]).data
              for x, y in zip(xs, F.split_axis(dep_ys, batchsize, 0))]
    return cat_ys, dep_ys
评论列表
文章目录