def predict(self, xs):
    """Compute per-position output scores for a batch of sentences.

    Args:
        xs: iterable of split sentences. Each one is passed to
            ``self.extractor.process``, which must return a 3-tuple
            ``(w, c, l)``: ``w`` is fed to ``self.emb_word``, ``c`` to
            ``self.emb_char``, and ``l`` is used as the max-pooling
            window height over the character convolution output
            (presumably the padded character length -- confirm against
            the extractor).

    Returns:
        A list with one array per input sentence, holding the output of
        ``self.linear2`` for every position except the first and the
        last row. An empty batch returns an empty list.
    """
    features = [self.extractor.process(x) for x in xs]
    if not features:
        # zip(*[]) produces nothing, so the three-way unpacking below
        # would raise ValueError; an empty batch simply has no scores.
        return []
    batchsize = len(features)
    ws, cs, ls = zip(*features)

    # A list (not a lazy ``map``) so the value is the same kind of object
    # on Python 2 and Python 3.
    word_vecs = [self.emb_word(w) for w in ws]

    char_vecs = []
    for c, l in zip(cs, ls):
        # Insert a channel axis at position 1 before the 2D convolution.
        conv = self.conv_char(F.expand_dims(self.emb_char(c), 1))
        pooled = F.max_pooling_2d(conv, (l, 1))
        # NOTE(review): squeeze without an axis drops every size-1
        # dimension, so an input with a single token would also lose its
        # token axis. The [1:-1] slice at the end suggests the extractor
        # always adds boundary tokens -- confirm.
        char_vecs.append(F.squeeze(pooled))

    xs_f = [F.dropout(F.concat([w, c]), self.dropout_ratio, train=self.train)
            for w, c in zip(word_vecs, char_vecs)]
    # The backward LSTM reads every sentence in reverse order.
    xs_b = [x[::-1] for x in xs_f]

    cx_f, hx_f, cx_b, hx_b = self._init_state(batchsize)
    _, _, hs_f = self.lstm_f(hx_f, cx_f, xs_f, train=self.train)
    _, _, hs_b = self.lstm_b(hx_b, cx_b, xs_b, train=self.train)
    # Undo the reversal so both directions line up position by position.
    hs_b = [h[::-1] for h in hs_b]

    ys = [self.linear2(F.relu(self.linear1(F.concat([h_f, h_b]))))
          for h_f, h_b in zip(hs_f, hs_b)]
    # Drop the first and last rows (presumably sentence-boundary tokens
    # added by the extractor -- confirm).
    return [y.data[1:-1] for y in ys]
评论列表
文章目录