def forward(self, pretrained_word_tokens, word_tokens, pos_tokens):
    """Compute arc and label scores for a batch of sentences.

    Each sentence ``idx`` (for ``idx`` in ``range(len(word_tokens))``) is
    featurized as follows:
      * its pretrained-word, word and POS id sequences are turned into arrays
        with ``self.xp.array``;
      * they are embedded by ``self.embed[0]``, ``self.embed[1]`` and
        ``self.embed[2]`` respectively;
      * the pretrained and trainable word embeddings are summed;
      * the summed word embedding and the tag embedding are each passed through
        ``teras_F.dropout`` with ratio ``self.embed._dropout_ratio``;
      * the results are joined with ``F.concat``.

    The per-sentence features go through ``self.blstm`` and are padded with
    ``F.pad_sequence``. The padded output is then scored by two
    MLP-pair + biaffine heads: one for arcs and one for labels.

    Args:
        pretrained_word_tokens: per-sentence token id sequences for
            ``self.embed[0]``; indexed by sentence position.
        word_tokens: per-sentence token id sequences for ``self.embed[1]``;
            its length defines the batch size.
        pos_tokens: per-sentence POS id sequences for ``self.embed[2]``;
            indexed by sentence position.

    Returns:
        A pair ``(arc_logits, label_logits)``. ``arc_logits`` is the arc
        biaffine output with axis 3 squeezed away. ``label_logits`` is the
        label biaffine output as returned.
    """

    def _sentence_features(idx):
        # Order matters here: the embeddings are computed pretrained -> word
        # -> tag, and dropout is applied to words before tags, so any random
        # state is consumed in the same sequence on every run.
        pretrained_vecs = self.embed[0](self.xp.array(pretrained_word_tokens[idx]))
        word_vecs = self.embed[1](self.xp.array(word_tokens[idx]))
        word_vecs += pretrained_vecs
        tag_vecs = self.embed[2](self.xp.array(pos_tokens[idx]))
        ratio = self.embed._dropout_ratio
        return F.concat([
            teras_F.dropout(word_vecs, ratio),
            teras_F.dropout(tag_vecs, ratio),
        ])

    features = [_sentence_features(idx) for idx in range(len(word_tokens))]

    # Contextualize every sentence, then pad them to a common length.
    hidden = F.pad_sequence(self.blstm(features))

    # Arc head: dependent-side MLP is evaluated before head-side MLP.
    arc_logits = F.squeeze(
        self.arc_biaffine(self.mlp_arc_dep(hidden), self.mlp_arc_head(hidden)),
        axis=3,
    )

    # Label head: same dependent-then-head evaluation order.
    label_logits = self.label_biaffine(
        self.mlp_label_dep(hidden), self.mlp_label_head(hidden))

    return arc_logits, label_logits
评论列表
文章目录