def decoding(self, src_encodings):
    """Decode predicted heads and labels for one input sequence.

    Scores all candidate arcs and labels with ``self.cal_scores``, selects
    heads by running ``mst`` on the (transposed) arc-score matrix, and then,
    for each position, picks the label with the highest score for the head
    chosen at that position.

    Args:
        src_encodings: encoded input, passed unchanged to ``self.cal_scores``.

    Returns:
        A ``(pred_heads, pred_labels)`` tuple. ``pred_heads`` is exactly what
        ``mst`` returns. ``pred_labels`` is a list holding one ``np.argmax``
        label index per element of ``pred_heads``.
    """
    s_arc, s_label = self.cal_scores(src_encodings)

    # NOTE: should transpose before calling `mst` method!
    # Original shape annotation: (src_len, src_len). Presumably this swaps
    # the head/dependent axes into the orientation `mst` expects.
    # TODO(review): confirm against the `mst` implementation.
    arc_weights = s_arc.npvalue().transpose()

    # Each element of `s_label` yields a 2-D score matrix. Stacking puts the
    # list index on axis 0. transpose((2, 1, 0)) reverses all axes, so the
    # list index moves last. Original shape annotation:
    # (src_len, src_len, n_labels).
    label_scores = np.asarray(
        [x.npvalue() for x in s_label]
    ).transpose((2, 1, 0))

    pred_heads = mst(arc_weights)

    # zip walks axis 0 of `label_scores`, one row per position. `row[head]`
    # then selects that position's label-score vector for its predicted head.
    pred_labels = [
        np.argmax(row[head]) for head, row in zip(pred_heads, label_scores)
    ]
    return pred_heads, pred_labels
评论列表
文章目录