from typing import List

from scipy.stats import kendalltau, spearmanr

# ContextAndQuestion, trivia_span_scores, squad_span_scores, and Evaluation
# are project-internal helpers assumed to be importable in this module.


def evaluate(self, data: List[ContextAndQuestion], true_len, **kwargs):
    # Score every question with the dataset-specific metrics; each row of
    # `scores` is [accuracy, f1, text-accuracy, text-f1].
    if self.text_eval == "triviaqa":
        scores = trivia_span_scores(data, kwargs["spans"])
    elif self.text_eval == "squad":
        scores = squad_span_scores(data, kwargs["spans"])
    else:
        raise RuntimeError("Unknown text_eval mode: %r" % self.text_eval)
    # Aggregate only over questions that have at least one annotated answer
    # span (scores[...] relies on numpy-style boolean-mask indexing).
    has_answer = [len(x.answer.answer_spans) > 0 for x in data]
    aggregated_scores = scores[has_answer].mean(axis=0)

    # Prefix scalar names with this evaluator's bound setting.
    prefix = "b%d/" % self.bound
    scalars = {
        prefix + "accuracy": aggregated_scores[0],
        prefix + "f1": aggregated_scores[1],
        prefix + "text-accuracy": aggregated_scores[2],
        prefix + "text-f1": aggregated_scores[3]
    }
    # Pick the rank-correlation metric used to relate confidences to scores.
    if self.rank_metric == "spr":
        metric = spearmanr
    elif self.rank_metric == "k-tau":
        metric = kendalltau
    else:
        raise ValueError("Unknown rank_metric: %r" % self.rank_metric)
if "none_prob" in kargs:
none_conf = kargs["none_prob"]
scalars[prefix + "none-text-f1-" + self.rank_metric] = metric(none_conf, scores[:, 3])[0]
scalars[prefix + "none-span-accuracy-" + self.rank_metric] = metric(none_conf, scores[:, 0])[0]
conf = kargs["conf"]
scalars[prefix + "score-text-f1-" + self.rank_metric] = metric(conf, scores[:, 3])[0]
scalars[prefix + "score-span-accuracy-" + self.rank_metric] = metric(conf, scores[:, 0])[0]
return Evaluation(scalars)
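
To make the aggregation and rank-correlation steps concrete, here is a minimal, self-contained sketch using only numpy and scipy. The score matrix, answer mask, and confidence vector are made-up stand-ins for the real `scores`, `has_answer`, and `kwargs["conf"]` values computed inside evaluate():

import numpy as np
from scipy.stats import kendalltau, spearmanr

# Hypothetical per-question scores; columns follow the same layout as above:
# [accuracy, f1, text-accuracy, text-f1], one row per question.
scores = np.array([
    [1.0, 1.0, 1.0, 1.0],
    [0.0, 0.5, 0.0, 0.4],
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 0.8, 1.0, 0.9],
])

# Boolean mask selecting questions with at least one annotated answer span.
has_answer = [True, True, False, True]
aggregated = scores[has_answer].mean(axis=0)
print("accuracy=%.3f  text-f1=%.3f" % (aggregated[0], aggregated[3]))

# Both rank metrics return (statistic, p-value); evaluate() keeps only [0].
conf = np.array([0.9, 0.6, 0.1, 0.8])  # hypothetical per-question confidences
print("spearman rho: %.3f" % spearmanr(conf, scores[:, 3])[0])
print("kendall tau:  %.3f" % kendalltau(conf, scores[:, 3])[0])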