biaffine.py source code

python

Project: nn4nlp-code  Author: neubig
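def decoding(self, src_encodings):
    # Excerpt of a parser method: `np` is NumPy, and `mst` is a maximum-spanning-tree
    # decoder defined elsewhere in the project (not shown here).
    src_len = len(src_encodings)

    # NOTE: the arc scores must be transposed before calling `mst`!
    s_arc, s_label = self.cal_scores(src_encodings)
    s_arc_values = s_arc.npvalue().transpose()  # src_len, src_len
    s_label_values = np.asarray([x.npvalue() for x in s_label]).transpose((2, 1, 0))  # src_len, src_len, n_labels

    # Commented-out alternative (unused; `batch` is undefined in this method):
    # weights = np.zeros((src_len + 1, src_len + 1))
    # weights[0, 1:(src_len + 1)] = np.inf
    # weights[1:(src_len + 1), 0] = np.inf
    # weights[1:(src_len + 1), 1:(src_len + 1)] = s_arc_values[batch]
    weights = s_arc_values
    # Decode the highest-scoring tree; pred_heads[i] is the predicted head of word i.
    pred_heads = mst(weights)
    # For each word, pick the best-scoring label at its predicted head.
    pred_labels = [np.argmax(labels[head]) for head, labels in zip(pred_heads, s_label_values)]

    return pred_heads, pred_labels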
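The excerpt does not show `mst`. A standard way to decode the highest-scoring dependency tree from a matrix of arc scores is the Chu-Liu/Edmonds maximum spanning arborescence algorithm; the following is a minimal NumPy sketch of that technique, not the project's own `mst`. It assumes scores indexed as `[dependent, head]`, node 0 as an artificial root, and `-1` as the root's head; `chu_liu_edmonds` and `decode` are hypothetical names. `decode` mirrors the last step of `decoding`: each word's label is the argmax of the label scores at its predicted head.

```python
import numpy as np


def _find_cycle(heads):
    """Return the nodes of one cycle in a head array, or None if there is none."""
    n = len(heads)
    done = np.zeros(n, dtype=bool)  # nodes already known to reach the root
    for start in range(1, n):
        order = {}  # node -> position on the current path
        path = []
        node = start
        while node != -1 and not done[node] and node not in order:
            order[node] = len(path)
            path.append(node)
            node = int(heads[node])
        if node != -1 and node in order:
            return path[order[node]:]
        for v in path:
            done[v] = True
    return None


def chu_liu_edmonds(scores):
    """Maximum spanning arborescence rooted at node 0 (assumed layout).

    scores[d, h] is the score of the arc from head h to dependent d.
    Returns an int array `heads` with heads[0] = -1 and heads[d] = head of d.
    """
    scores = np.array(scores, dtype=float)  # copy, since entries are masked below
    n = scores.shape[0]
    np.fill_diagonal(scores, -np.inf)  # no self-loops
    scores[0, :] = -np.inf             # the root takes no head
    heads = scores.argmax(axis=1)      # greedy: best head for every node
    heads[0] = -1
    cycle = _find_cycle(heads)
    if cycle is None:
        return heads  # the greedy choice is already a tree

    # Contract the cycle into a single new node `c`.
    in_cycle = set(cycle)
    others = [v for v in range(n) if v not in in_cycle]  # root stays at index 0
    c = len(others)
    cycle_idx = np.array(cycle)
    sub = np.full((c + 1, c + 1), -np.inf)
    out_arg = {}  # outside dependent -> best head inside the cycle
    in_arg = {}   # outside head -> cycle node whose arc it replaces
    for i, d in enumerate(others):
        for j, h in enumerate(others):
            sub[i, j] = scores[d, h]
        k = int(np.argmax(scores[d, cycle_idx]))
        sub[i, c] = scores[d, cycle[k]]
        out_arg[d] = cycle[k]
    cycle_arc = scores[cycle_idx, heads[cycle_idx]]  # score of each cycle arc
    for j, h in enumerate(others):
        gain = scores[cycle_idx, h] - cycle_arc  # change in score if h enters here
        k = int(np.argmax(gain))
        sub[c, j] = gain[k]
        in_arg[h] = cycle[k]

    # Solve the smaller problem, then expand the contracted node again.
    sub_heads = chu_liu_edmonds(sub)
    result = heads.copy()  # cycle nodes keep their cycle arcs by default
    for i, d in enumerate(others[1:], start=1):
        h = sub_heads[i]
        result[d] = out_arg[d] if h == c else others[h]
    entry_head = others[sub_heads[c]]
    result[in_arg[entry_head]] = entry_head  # break the cycle at one node
    result[0] = -1
    return result


def decode(arc_scores, label_scores):
    """arc_scores: (n, n) as [dep, head]; label_scores: (n, n, n_labels) as [dep, head, label]."""
    heads = chu_liu_edmonds(arc_scores)
    labels = [int(np.argmax(label_scores[d, heads[d]])) for d in range(1, len(heads))]
    return heads, labels
```

This recursive version contracts one cycle per call, so it is O(n^3) in the worst case, which is acceptable for sentence-length inputs. Calling `decode(s_arc_values, s_label_values)` would be the analogous step in `decoding`, but only if those arrays really use the `[dependent, head]` layout with the root at index 0, which the excerpt does not confirm.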