pytorch_model.py source code

python

Project: biaffineparser    Author: chantera
import numpy as np
import torch
import torch.nn as nn


# forward() method of the parser model (the class definition is omitted in the original snippet).
def forward(self, pretrained_word_tokens, word_tokens, pos_tokens):
    # Embed words (pretrained + trainable) and POS tags.
    lengths = np.array([len(tokens) for tokens in word_tokens])
    X = self.forward_embed(
        pretrained_word_tokens, word_tokens, pos_tokens, lengths)
    # Sort sentences by length in descending order, as required by pack_padded_sequence.
    indices = np.argsort(-lengths).astype(np.int64)
    lengths = lengths[indices]
    X = torch.stack([X[idx] for idx in indices])
    X = nn.utils.rnn.pack_padded_sequence(X, lengths, batch_first=True)
    # Contextualize with the BiLSTM, unpack, and restore the original batch order
    # (_model_var is a helper defined elsewhere in this file).
    R = self.blstm(X)[0]
    R = nn.utils.rnn.pad_packed_sequence(R, batch_first=True)[0]
    R = R.index_select(dim=0, index=_model_var(
        self, torch.from_numpy(np.argsort(indices).astype(np.int64))))
    # MLP projections and biaffine scoring for arcs.
    H_arc_head = self.mlp_arc_head(R)
    H_arc_dep = self.mlp_arc_dep(R)
    arc_logits = self.arc_biaffine(H_arc_dep, H_arc_head)
    arc_logits = torch.squeeze(arc_logits, dim=3)
    # MLP projections and biaffine scoring for dependency labels.
    H_label_dep = self.mlp_label_dep(R)
    H_label_head = self.mlp_label_head(R)
    label_logits = self.label_biaffine(H_label_dep, H_label_head)
    return arc_logits, label_logits
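
The arc and label scorers called above (self.arc_biaffine, self.label_biaffine) are biaffine attention layers. Below is a minimal sketch of such a layer, using a hypothetical Biaffine module for illustration; the actual implementation in chantera/biaffineparser may differ in naming, bias handling, and initialization.

import torch
import torch.nn as nn


class Biaffine(nn.Module):
    """Hypothetical biaffine scorer: s(x, y) = [x; 1] W [y; 1]^T per output channel."""

    def __init__(self, in1_features, in2_features, out_features, bias=(True, True)):
        super().__init__()
        self.bias = bias
        # One (in1 [+1]) x (in2 [+1]) weight matrix per output channel.
        self.weight = nn.Parameter(torch.empty(
            out_features, in1_features + int(bias[0]), in2_features + int(bias[1])))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, y):
        # x: (batch, len, in1) dependent representations
        # y: (batch, len, in2) head representations
        if self.bias[0]:
            x = torch.cat([x, x.new_ones(*x.shape[:-1], 1)], dim=-1)
        if self.bias[1]:
            y = torch.cat([y, y.new_ones(*y.shape[:-1], 1)], dim=-1)
        # scores: (batch, len, len, out_features)
        return torch.einsum('bxi,oij,byj->bxyo', x, self.weight, y)

With out_features=1, the arc scorer returns a (batch, len, len, 1) tensor, which the squeeze(dim=3) call in forward() reduces to per-pair arc scores; the label scorer would use out_features equal to the number of dependency label types.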