def parse(self, pretrained_word_tokens=None,
          word_tokens=None, pos_tokens=None):
    """Decode head indices and dependency labels for each batch item.

    When ``word_tokens`` is given, ``self.forward`` is invoked first with
    the three token arguments; otherwise the logits already stored on
    ``self._arc_logits`` / ``self._label_logits`` are decoded as they are.

    For every item, the arc logits are cropped to the item's length,
    normalised with ``softmax2d`` and passed to ``mst`` to obtain one head
    index per position. Labels are the argmax over the label distribution
    selected at each (position, head) pair. Position 0 is always given
    ``self._ROOT_LABEL``. If the number of other positions labelled ROOT
    is not exactly one, the labels are repaired so that the positions
    whose head index is 0 carry the ROOT label.

    Args:
        pretrained_word_tokens: Forwarded unchanged to ``self.forward``.
        word_tokens: Forwarded to ``self.forward``; ``None`` skips the
            forward pass.
        pos_tokens: Forwarded unchanged to ``self.forward``.

    Returns:
        A pair ``(arcs_batch, labels_batch)`` of lists with one entry per
        batch item: the result of ``mst`` and the label index array.
    """
    # A forward pass is only triggered by new inputs; with no inputs the
    # method decodes whatever logits a previous call left on ``self``.
    if word_tokens is not None:
        self.forward(pretrained_word_tokens, word_tokens, pos_tokens)

    root_label = self._ROOT_LABEL
    arc_scores_batch = cuda.to_cpu(self._arc_logits.data)
    label_scores_host = cuda.to_cpu(self._label_logits.data)
    # Swap axes 1 and 2 of the 4-D label logits so each item can be indexed
    # as [position, head] below.
    # NOTE(review): the meaning of the individual axes is not visible here
    # -- confirm against ``forward``.
    label_scores_batch = np.transpose(label_scores_host, (0, 2, 1, 3))

    decoded_heads, decoded_labels = [], []
    for arc_scores, label_scores, n in zip(
            arc_scores_batch, label_scores_batch, self._lengths):
        # Crop to the item's length before normalising and decoding.
        heads = mst(softmax2d(arc_scores[:n, :n]))

        # Pick, for every position, the label scores at its decoded head.
        label_dist = softmax2d(label_scores[np.arange(n), heads])
        labels = np.argmax(label_dist, axis=1)
        labels[0] = root_label

        # Positions 1..n-1; the "+ 1" maps indices within this slice back
        # to positions in the full item.
        positions = np.arange(1, n)
        root_labeled = np.where(labels[positions] == root_label)[0] + 1
        n_roots = len(root_labeled)

        if n_roots != 1:
            if n_roots > 1:
                # Zero out the ROOT score for the offending positions and
                # re-take the argmax to give them a replacement label.
                label_dist[root_labeled, root_label] = 0
                labels[root_labeled] = \
                    np.argmax(label_dist[root_labeled], axis=1)
            # Positions whose head index is 0 receive the ROOT label.
            attached_to_root = np.where(heads[positions] == 0)[0] + 1
            labels[attached_to_root] = root_label

        decoded_heads.append(heads)
        decoded_labels.append(labels)

    return decoded_heads, decoded_labels
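# 评论列表 (comment list) -- web-page residue, not part of the code
# 文章目录 (table of contents) -- web-page residue, not part of the code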