entailment_graph.py source code

python

Project: OKR  Author: vered1986
import numpy as np
from sklearn.metrics import precision_recall_fscore_support


def compute_entities_f1(gold_graph, pred_graph):
    """
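    Compute the agreement on the entity entailment graph: score each entity's edges with F1
    and return the average over the entities that appear in both graphs.
    The candidate edges are all ordered pairs of distinct mentions of each gold entity.
    :param gold_graph: the first annotator's graph
    :param pred_graph: the second annotator's graph
    :return: the mean F1 score over the entity edges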
    """

    # Get all the possible edges in the entity entailment graph
    all_edges = {str(entity): set([(str(m1), str(m2))
                                   for m1 in entity.mentions.values()
                                   for m2 in entity.mentions.values() if m1 != m2])
                 for entity in gold_graph.entities.values() if len(entity.mentions) > 1}

    # Get the binary predictions/gold for these edges
    str_entities_gold = {entity: str(entity) for entity in gold_graph.entities.values()}
    entity_entailments_gold = {str_entities_gold[entity]:
                                   [1 if (m1, m2) in set(entity.entailment_graph.mentions_graph) else 0
                                    for (m1, m2) in all_edges[str_entities_gold[entity]]]
                               for entity in gold_graph.entities.values()
                               if str_entities_gold[entity] in all_edges}

    str_entities_pred = {entity: str(entity) for entity in pred_graph.entities.values()}
    entity_entailments_pred = {str_entities_pred[entity]:
                                   [1 if (m1, m2) in set(entity.entailment_graph.mentions_graph) else 0
                                    for (m1, m2) in all_edges[str_entities_pred[entity]]]
                               for entity in pred_graph.entities.values()
                               if str_entities_pred[entity] in all_edges}

    mutual_entities = list(set(entity_entailments_gold.keys()).intersection(entity_entailments_pred.keys()))

    # If both graphs contain no entailments, the score should be one
    f1 = np.mean([precision_recall_fscore_support(entity_entailments_gold[entity], entity_entailments_pred[entity],
                                                  average='binary')[2]
                  if np.sum(entity_entailments_gold[entity]) > 0 or np.sum(entity_entailments_pred[entity]) > 0 else 1.0
                  for entity in mutual_entities])

    return f1
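
The scoring scheme above can be tried without the OKR graph classes. The following is a minimal sketch, not the project's code: `edge_labels`, `binary_f1`, `mean_entity_f1` and the toy mentions are hypothetical names and data introduced only for illustration, and a hand-written F1 stands in for sklearn's `precision_recall_fscore_support`. It assumes mentions are plain strings and that entailments are given as (mention, mention) tuples.

```python
from itertools import permutations


def edge_labels(mentions, entailed_edges):
    """Return a 0/1 vector over all ordered pairs of distinct mentions."""
    # Sorting fixes the edge order so gold and predicted vectors line up.
    candidates = list(permutations(sorted(mentions), 2))
    entailed = set(entailed_edges)
    return [1 if edge in entailed else 0 for edge in candidates]


def binary_f1(gold, pred):
    """F1 of the positive class for two aligned 0/1 label lists."""
    tp = sum(1 for g, p in zip(gold, pred) if g == 1 and p == 1)
    fp = sum(1 for g, p in zip(gold, pred) if g == 0 and p == 1)
    fn = sum(1 for g, p in zip(gold, pred) if g == 1 and p == 0)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def mean_entity_f1(gold_labels, pred_labels):
    """Average the per-entity F1 over entities present in both dicts."""
    scores = []
    for entity in sorted(set(gold_labels) & set(pred_labels)):
        gold, pred = gold_labels[entity], pred_labels[entity]
        if sum(gold) == 0 and sum(pred) == 0:
            # Neither annotator marked an entailment: count as full agreement.
            scores.append(1.0)
        else:
            scores.append(binary_f1(gold, pred))
    # With no shared entities there is nothing to average.
    return sum(scores) / len(scores) if scores else float("nan")


# Hypothetical toy data: two entities with two mentions each.
mentions = {"E1": ["dog", "animal"], "E2": ["car", "auto"]}
gold_edges = {"E1": [("dog", "animal")], "E2": []}
pred_edges = {"E1": [("dog", "animal"), ("animal", "dog")], "E2": []}

gold_labels = {e: edge_labels(mentions[e], gold_edges[e]) for e in mentions}
pred_labels = {e: edge_labels(mentions[e], pred_edges[e]) for e in mentions}

result = mean_entity_f1(gold_labels, pred_labels)
print(round(result, 4))  # 0.8333
```

In this toy case E1 scores 2/3 (one correct edge, one spurious edge) and E2 scores 1.0 because neither side marked an entailment, so the mean is 5/6. When the two graphs share no multi-mention entity, the average is taken over an empty list, so callers may want to handle a NaN result.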