def evaluate(self, data: List[ContextAndQuestion], true_len, **kwargs):
    """Score predicted answer spans and aggregate them into an ``Evaluation``.

    For each question (or each ``(question_id, doc_id)`` pair when
    ``self.per_doc`` is set) only the paragraph whose ``score`` is highest is
    kept, and the text EM/F1 of those selected paragraphs is averaged.

    Args:
        data: The evaluated paragraphs, aligned index-by-index with the
            ``"span"`` and ``"score"`` entries of ``kwargs``.
        true_len: Not used by this method; kept for interface compatibility.
        **kwargs: Must contain ``"span"`` (predicted spans, passed to the
            span-scoring function) and ``"score"`` (one score per entry of
            ``data``, used to choose the best paragraph per question).

    Returns:
        An ``Evaluation`` whose keys are prefixed with ``"b<self.bound>/"``:
        ``question-text-em`` / ``question-text-f1`` always;
        ``text-em-k-tau`` / ``text-f1-k-tau`` when ``self.k_tau`` is set;
        ``paragraph-text-em`` / ``paragraph-text-f1`` when
        ``self.paragraph_level`` is set.

    Raises:
        RuntimeError: If ``self.eval`` is neither ``"triviaqa"`` nor ``"squad"``.
    """
    best_spans = kwargs["span"]
    span_logits = kwargs["score"]
    if self.eval == "triviaqa":
        scores = trivia_span_scores(data, best_spans)
    elif self.eval == "squad":
        scores = squad_span_scores(data, best_spans)
    else:
        raise RuntimeError(
            "Unknown eval mode %r (expected 'triviaqa' or 'squad')" % (self.eval,)
        )

    # For each question (or question/document pair) remember the index of the
    # highest-scoring paragraph; on ties the first paragraph seen is kept.
    best_index = {}
    for i, point in enumerate(data):
        key = (point.question_id, point.doc_id) if self.per_doc else point.question_id
        current = best_index.get(key)
        if current is None or span_logits[i] > span_logits[current]:
            best_index[key] = i
    selected = list(best_index.values())

    # Columns 2 and 3 of ``scores`` are reported as text EM and text F1.
    out = {
        "question-text-em": scores[selected, 2].mean(),
        "question-text-f1": scores[selected, 3].mean(),
    }
    if self.k_tau:
        # Kendall rank correlation between span scores and per-paragraph EM/F1.
        out["text-em-k-tau"] = kendalltau(span_logits, scores[:, 2])[0]
        out["text-f1-k-tau"] = kendalltau(span_logits, scores[:, 3])[0]
    if self.paragraph_level:
        # Average only over paragraphs with at least one gold answer span.
        # dtype=bool keeps this a boolean mask even when ``data`` is empty
        # (np.array([]) would be float64 and invalid as an index).
        has_answer = np.array(
            [len(x.answer.answer_spans) > 0 for x in data], dtype=bool
        )
        out["paragraph-text-em"] = scores[has_answer, 2].mean()
        out["paragraph-text-f1"] = scores[has_answer, 3].mean()
    prefix = "b%d/" % self.bound
    return Evaluation({prefix + k: v for k, v in out.items()})
# (scraped page label: "Comment list")
# (scraped page label: "Table of contents")