import chainer
import chainer.functions as F
# `scanl` and the `IGNORE` label constant come from the surrounding module
# (a sketch of `scanl` follows the listing).


def forward(self, ws, ss, ps, dep_ts=None):
    batchsize = len(ws)
    xp = chainer.cuda.get_array_module(ws[0])
    # cumulative sentence lengths: the boundaries used to split the flat
    # token matrix back into per-sentence chunks
    split = scanl(lambda x, y: x + y, 0, [w.shape[0] for w in ws])[1:-1]
    # word, suffix and prefix embeddings, concatenated per token
    wss = self.emb_word(F.hstack(ws))
    sss = F.reshape(self.emb_suf(F.vstack(ss)), (-1, 4 * self.afix_dim))
    pss = F.reshape(self.emb_prf(F.vstack(ps)), (-1, 4 * self.afix_dim))
    ins = F.dropout(F.concat([wss, sss, pss]), self.dropout_ratio, train=self.train)
    # per-sentence inputs for the forward LSTM; reversed copies feed the backward LSTM
    xs_f = list(F.split_axis(ins, split, 0))
    xs_b = [x[::-1] for x in xs_f]
    cx_f, hx_f, cx_b, hx_b = self._init_state(xp, batchsize)
    _, _, hs_f = self.lstm_f(hx_f, cx_f, xs_f, train=self.train)
    _, _, hs_b = self.lstm_b(hx_b, cx_b, xs_b, train=self.train)
    hs_b = [x[::-1] for x in hs_b]
    # hs: per-sentence (sentence length, 2 * hidden dim) matrices;
    # the final cat_ys is a list of (sentence length, number of categories) matrices
    hs = [F.concat([h_f, h_b]) for h_f, h_b in zip(hs_f, hs_b)]
    # arc (head attachment) scores: one (sentence length, sentence length)
    # matrix per sentence
    dep_ys = [self.biaffine_arc(
        F.elu(F.dropout(self.arc_dep(h), 0.32, train=self.train)),
        F.elu(F.dropout(self.arc_head(h), 0.32, train=self.train))) for h in hs]
    # use gold heads when they are given (training); otherwise fall back to
    # the heads predicted by the arc scorer
    # if dep_ts is not None and random.random() >= 0.5:
    if dep_ts is not None:
        heads = dep_ts
    else:
        heads = [F.argmax(y, axis=1) for y in dep_ys]
    # look up each token's head vector and project it for category scoring
    heads = F.elu(F.dropout(
        self.rel_head(
            F.vstack([F.embed_id(t, h, ignore_label=IGNORE)
                      for h, t in zip(hs, heads)])),
        0.32, train=self.train))
    childs = F.elu(F.dropout(self.rel_dep(F.vstack(hs)), 0.32, train=self.train))
    # supertag scores over all tokens, then split back into per-sentence chunks
    cat_ys = self.biaffine_tag(childs, heads)
    cat_ys = list(F.split_axis(cat_ys, split, 0))
    return cat_ys, dep_ys
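

# --- Helper sketches ---------------------------------------------------------
# `scanl` is a small utility from the surrounding module; a plausible
# definition (an assumption) that matches its use above is the Haskell-style
# cumulative fold, whose inner slice [1:-1] yields the split points passed to
# F.split_axis:

def scanl(f, base, xs):
    # e.g. scanl(lambda a, b: a + b, 0, [3, 2, 4]) -> [0, 3, 5, 9]
    out = [base]
    for x in xs:
        base = f(base, x)
        out.append(base)
    return out


# `self.biaffine_arc` / `self.biaffine_tag` are biaffine scorers in the style
# of Dozat & Manning (2017). The class below is a minimal sketch of the arc
# variant under that assumption, not the repository's actual `Biaffine` link;
# the tag scorer would additionally need an output size equal to the number of
# categories.

import chainer.links as L


class Biaffine(chainer.Chain):

    def __init__(self, in_dim):
        # a single (in_dim, in_dim + 1) weight; the extra input column acts as
        # a bias term on the head representation
        super(Biaffine, self).__init__(
            linear=L.Linear(in_dim + 1, in_dim, nobias=True))

    def __call__(self, dep, head):
        # dep, head: (sentence length, in_dim) for a single sentence
        n = dep.data.shape[0]
        ones = self.xp.ones((n, 1), dtype=dep.data.dtype)
        dep = F.concat([dep, ones], axis=1)                    # append bias column
        return F.matmul(self.linear(dep), head, transb=True)   # (n, n) arc scores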