def forward(self, pretrained_word_tokens, word_tokens, pos_tokens):
    """Encode a batch with the BiLSTM and score it with both biaffine layers.

    The embedded batch is reordered longest-first so it can be packed for
    ``self.blstm``. The BiLSTM output is then restored to the caller's
    original batch order before the MLP and biaffine scoring layers run.

    Args:
        pretrained_word_tokens: forwarded unchanged to ``self.forward_embed``.
        word_tokens: one token sequence per batch item; ``len()`` of each
            entry is used as that item's length.
        pos_tokens: forwarded unchanged to ``self.forward_embed``.

    Returns:
        A pair ``(arc_logits, label_logits)``:

        - ``arc_logits`` is the output of ``self.arc_biaffine`` with
          dimension 3 squeezed out.
        - ``label_logits`` is the output of ``self.label_biaffine``.
    """
    seq_lengths = np.array([len(item) for item in word_tokens])
    embedded = self.forward_embed(
        pretrained_word_tokens, word_tokens, pos_tokens, seq_lengths)

    # pack_padded_sequence expects lengths in decreasing order (the default
    # behaviour), so build a longest-first permutation of the batch.
    order = np.argsort(-seq_lengths).astype(np.int64)
    # The argsort of a permutation is its inverse; used to undo the sort.
    restore = np.argsort(order).astype(np.int64)

    batch = torch.stack([embedded[i] for i in order])
    packed = nn.utils.rnn.pack_padded_sequence(
        batch, seq_lengths[order], batch_first=True)
    encoded, _ = nn.utils.rnn.pad_packed_sequence(
        self.blstm(packed)[0], batch_first=True)
    # _model_var is a module-level helper defined outside this view.
    # It wraps the index tensor before index_select.
    encoded = encoded.index_select(
        dim=0, index=_model_var(self, torch.from_numpy(restore)))

    # Call order is kept as-is. If any of these layers is stochastic
    # (e.g. dropout), reordering would change RNG consumption.
    arc_head = self.mlp_arc_head(encoded)
    arc_dep = self.mlp_arc_dep(encoded)
    arc_logits = torch.squeeze(self.arc_biaffine(arc_dep, arc_head), dim=3)

    label_dep = self.mlp_label_dep(encoded)
    label_head = self.mlp_label_head(encoded)
    label_logits = self.label_biaffine(label_dep, label_head)
    return arc_logits, label_logits
评论列表
文章目录