chainer_model.py source code

python

Project: biaffineparser | Author: chantera
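```python
import chainer.functions as F
# teras_F: dropout helper from the author's teras library (its import path is not shown in this snippet).


def forward(self, pretrained_word_tokens, word_tokens, pos_tokens):
    # Build one input matrix per sentence, since sentences differ in length.
    X = []
    batch = len(word_tokens)
    for i in range(batch):
        # Word representation = pretrained embedding + trainable embedding (same size).
        xs_words_pretrained = \
            self.embed[0](self.xp.array(pretrained_word_tokens[i]))
        xs_words = self.embed[1](self.xp.array(word_tokens[i]))
        xs_words += xs_words_pretrained
        # Part-of-speech tag embedding.
        xs_tags = self.embed[2](self.xp.array(pos_tokens[i]))
        # Apply dropout to each part, then concatenate along the feature axis (axis=1).
        xs = F.concat([
            teras_F.dropout(xs_words, self.embed._dropout_ratio),
            teras_F.dropout(xs_tags, self.embed._dropout_ratio)])
        X.append(xs)
    # Encode each sentence with the BiLSTM, then pad all outputs into one batch array.
    R = self.blstm(X)
    R = F.pad_sequence(R)
    # Arc scores: project into dependent and head spaces, then score every pair.
    H_arc_dep = self.mlp_arc_dep(R)
    H_arc_head = self.mlp_arc_head(R)
    arc_logits = self.arc_biaffine(H_arc_dep, H_arc_head)
    # Remove axis 3, which must have size 1 (a single score per dependent-head pair).
    arc_logits = F.squeeze(arc_logits, axis=3)
    # Label scores: same pattern, presumably with one output channel per dependency label.
    H_label_dep = self.mlp_label_dep(R)
    H_label_head = self.mlp_label_head(R)
    label_logits = self.label_biaffine(H_label_dep, H_label_head)
    return arc_logits, label_logits
```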
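Inside the loop in `forward`, each token's input vector is the element-wise sum of its pretrained and trainable word embeddings, concatenated with its POS-tag embedding. The sketch below reproduces only that arithmetic in plain Python, assuming toy lookup tables given as lists of vectors; `embed_token` and `embed_sentence` are hypothetical names, and dropout is left out because it does not change the shapes. The element-wise sum only works if both word embeddings have the same size, which the `+=` in the original code also requires.

```python
from typing import List, Sequence

Vector = List[float]


def embed_token(pretrained: Vector, trainable: Vector, tag: Vector) -> Vector:
    """Sum the two word embeddings element-wise, then append the tag embedding."""
    if len(pretrained) != len(trainable):
        raise ValueError("pretrained and trainable word embeddings must have the same size")
    word = [p + t for p, t in zip(pretrained, trainable)]
    return word + list(tag)  # concatenation along the feature axis


def embed_sentence(
    pretrained_ids: Sequence[int],
    word_ids: Sequence[int],
    tag_ids: Sequence[int],
    pretrained_table: Sequence[Vector],
    word_table: Sequence[Vector],
    tag_table: Sequence[Vector],
) -> List[Vector]:
    """Return one input vector per token (dropout omitted)."""
    return [
        embed_token(pretrained_table[p], word_table[w], tag_table[t])
        for p, w, t in zip(pretrained_ids, word_ids, tag_ids)
    ]


# Toy tables (hypothetical values): 2-dim word embeddings, 1-dim tag embeddings.
pretrained_table = [[1.0, 2.0], [3.0, 4.0]]
word_table = [[10.0, 10.0], [20.0, 20.0]]
tag_table = [[7.0], [8.0]]
xs = embed_sentence([0, 1], [1, 0], [0, 1], pretrained_table, word_table, tag_table)
print(xs)  # [[21.0, 22.0, 7.0], [13.0, 14.0, 8.0]]
```

With 2-dimensional word embeddings and a 1-dimensional tag embedding, each token ends up with 2 + 1 = 3 features, which is the width of each row the BiLSTM receives.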
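`F.pad_sequence` stacks the variable-length BiLSTM outputs into a single array padded to the longest sentence, filling with `0` by default. A plain-Python sketch of that behaviour on nested lists follows; `pad_sequences` is a hypothetical name, not a Chainer function.

```python
from typing import List

Vector = List[float]


def pad_sequences(seqs: List[List[Vector]], padding: float = 0.0) -> List[List[Vector]]:
    """Right-pad each sequence of feature vectors to the length of the longest one."""
    max_len = max(len(s) for s in seqs)
    # Feature size is taken from the first vector found, so at least one
    # sequence must be non-empty; all vectors are assumed to share that size.
    dim = len(next(v for s in seqs for v in s))
    return [
        [list(v) for v in s] + [[padding] * dim for _ in range(max_len - len(s))]
        for s in seqs
    ]


padded = pad_sequences([[[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]]])
print(padded)  # [[[1.0, 2.0], [0.0, 0.0]], [[3.0, 4.0], [5.0, 6.0]]]
```

Padded positions still receive arc and label scores afterwards, so a mask built from the true sentence lengths is typically needed when computing a loss or decoding.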
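`arc_biaffine` and `label_biaffine` score every (dependent, head) pair, but their definitions are not part of this snippet. The sketch below assumes a generic biaffine form in the style of Dozat and Manning: both inputs are extended with a constant 1, and each output channel k has its own weight matrix W_k, giving s[i][j][k] = [h_dep_i; 1]ᵀ W_k [h_head_j; 1]. It handles a single sentence without the batch axis, and every name in it is hypothetical. With one channel it mirrors the arc scorer, whose size-1 output axis is then dropped like `F.squeeze(..., axis=3)`; with several channels it mirrors the label scorer.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def augment(vec: Vector) -> Vector:
    # Appending a constant 1 lets one bilinear weight also hold linear and bias terms.
    return list(vec) + [1.0]


def bilinear(x: Vector, w: Matrix, y: Vector) -> float:
    """Compute x^T W y."""
    return sum(x[a] * w[a][b] * y[b] for a in range(len(x)) for b in range(len(y)))


def biaffine_scores(
    h_dep: List[Vector], h_head: List[Vector], weights: List[Matrix]
) -> List[List[List[float]]]:
    """Return scores[i][j][k] for dependent i, candidate head j, output channel k.

    Each weights[k] is a (d_dep + 1) x (d_head + 1) matrix.
    """
    dep = [augment(v) for v in h_dep]
    head = [augment(v) for v in h_head]
    return [[[bilinear(d, w, h) for w in weights] for h in head] for d in dep]


# One channel: score = dep * head + 2 (the bottom-right weight acts as a bias).
scores = biaffine_scores([[1.0], [2.0]], [[3.0], [4.0]], [[[1.0, 0.0], [0.0, 2.0]]])
arc = [[cell[0] for cell in row] for row in scores]  # analogue of squeezing the size-1 axis
print(arc)  # [[5.0, 6.0], [8.0, 10.0]]
```

Keeping an explicit channel axis lets the same routine produce both arc and label scores; whether the original layers are built this way is not shown in the snippet.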