chainer_model.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目: biaffineparser    作者: chantera    项目源码    文件源码
def parse(self, pretrained_word_tokens=None,
              word_tokens=None, pos_tokens=None):
        """Decode head indices and dependency labels for every sentence.

        If ``word_tokens`` is given, ``self.forward`` is run first on the
        three token inputs. Otherwise the arc/label logits left on ``self``
        by an earlier forward pass are decoded.

        Args:
            pretrained_word_tokens: passed through to ``self.forward``
                (only used when ``word_tokens`` is not None).
            word_tokens: passed through to ``self.forward``; ``None`` means
                "reuse the cached logits".
            pos_tokens: passed through to ``self.forward``.

        Returns:
            A pair ``(arcs_batch, labels_batch)`` of lists with one entry per
            sentence: the head array produced by ``mst`` and the label-id
            array (one label per position, ``length`` entries).

        Position 0 is always labelled ``ROOT``. For positions ``1..length-1``
        the ROOT label is assigned exactly to the tokens whose head is 0.
        Any other token the classifier labelled ROOT falls back to its best
        non-ROOT label.
        """
        if word_tokens is not None:
            self.forward(pretrained_word_tokens, word_tokens, pos_tokens)
        ROOT = self._ROOT_LABEL
        arcs_batch, labels_batch = [], []
        arc_logits = cuda.to_cpu(self._arc_logits.data)
        label_logits = cuda.to_cpu(self._label_logits.data)
        # Swap axes 1 and 2 so that, per sentence, label_logit[i, h] holds
        # the label scores for token i with head h (as used by the indexing
        # below). NOTE(review): the raw axis order of _label_logits is
        # defined elsewhere -- confirm.
        label_logits_by_dep = np.transpose(label_logits, (0, 2, 1, 3))

        for arc_logit, label_logit, length in zip(
                arc_logits, label_logits_by_dep, self._lengths):
            # Only the first `length` rows/cols are real tokens; the rest is
            # presumably batch padding.
            arc_probs = softmax2d(arc_logit[:length, :length])
            # Presumably a maximum-spanning-tree decoder returning one head
            # index per position -- confirm against `mst`.
            arcs = mst(arc_probs)
            # Label distribution of each token given its decoded head.
            label_probs = softmax2d(label_logit[np.arange(length), arcs])
            labels = np.argmax(label_probs, axis=1)
            labels[0] = ROOT

            tokens = np.arange(1, length)
            attached_to_root = arcs[tokens] == 0
            root_arc = tokens[attached_to_root]
            # ROOT-labelled tokens that do not actually hang off the root.
            wrong_roots = tokens[(labels[tokens] == ROOT) & ~attached_to_root]
            if len(wrong_roots) > 0:
                # Suppress ROOT so argmax picks the next-best label.
                label_probs[wrong_roots, ROOT] = 0
                labels[wrong_roots] = np.argmax(label_probs[wrong_roots],
                                                axis=1)
            # The token(s) attached to head 0 always carry the ROOT label.
            labels[root_arc] = ROOT

            arcs_batch.append(arcs)
            labels_batch.append(labels)

        return arcs_batch, labels_batch
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号