coref.py source code

Project: allennlp  Author: allenai
import torch


def _prune_and_sort_spans(mention_scores: torch.FloatTensor,
                          num_spans_to_keep: int) -> torch.LongTensor:
    """
    The indices of the top-k scoring spans according to ``mention_scores``. We return the
    indices in their original order, not ordered by score, so that we can rely on
    the ordering to consider the previous k spans as antecedents for each span later.

    Parameters
    ----------
    mention_scores : ``torch.FloatTensor``, required.
        The mention score for every candidate, with shape (batch_size, num_spans, 1).
    num_spans_to_keep : ``int``, required.
        The number of spans to keep when pruning.

    Returns
    -------
    top_span_indices : ``torch.LongTensor``
        The indices of the top-k scoring spans. Has shape (batch_size, num_spans_to_keep).
    """
    # Keep the k highest-scoring spans along the span dimension (dim 1).
    # topk returns int64 indices, ordered by descending score.
    # Shape: (batch_size, num_spans_to_keep, 1)
    _, top_span_indices = mention_scores.topk(num_spans_to_keep, 1)
    # Re-sort the kept indices so they follow the original document order.
    top_span_indices, _ = torch.sort(top_span_indices, 1)

    # Drop the trailing singleton dimension.
    # Shape: (batch_size, num_spans_to_keep)
    top_span_indices = top_span_indices.squeeze(-1)
    return top_span_indices
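The docstring's reasoning (select the top-k spans by score, then restore document order so that each span's preceding spans can serve as antecedents) can be traced on a tiny input. This is a sketch under stated assumptions: the scores are made up, and `previous_k_antecedents` and `max_antecedents` are hypothetical names introduced here for illustration, not part of the allennlp API.

```python
import torch

# Made-up scores for one document with six candidate spans,
# shaped (batch_size, num_spans, 1) as in the docstring.
mention_scores = torch.tensor([[[0.1], [2.0], [-1.0], [3.0], [0.5], [1.5]]])
num_spans_to_keep = 3

# Step 1: topk along the span dimension (dim 1). Its indices come back
# ordered by descending score, not by position in the document.
_, by_score = mention_scores.topk(num_spans_to_keep, 1)
print(by_score.squeeze(-1).tolist())  # [[3, 1, 5]]

# Step 2: sort the kept indices so they follow document order again.
in_order, _ = torch.sort(by_score, 1)
top_span_indices = in_order.squeeze(-1)  # (batch_size, num_spans_to_keep)
print(top_span_indices.tolist())  # [[1, 3, 5]]


def previous_k_antecedents(num_kept: int, max_antecedents: int):
    """Hypothetical helper (not an allennlp API): for each kept-span position i,
    return positions i-1, ..., i-max_antecedents and a mask of which exist."""
    targets = torch.arange(num_kept).unsqueeze(1)                # (num_kept, 1)
    offsets = torch.arange(1, max_antecedents + 1).unsqueeze(0)  # (1, max_antecedents)
    raw = targets - offsets                                      # broadcast subtraction
    valid = raw >= 0                                             # False where i - offset < 0
    return raw.clamp(min=0), valid                               # clamp keeps indexing in range


# Step 3: because the kept spans are in document order, "the previous k
# positions" are exactly the kept spans that precede each one in the text.
antecedents, valid = previous_k_antecedents(num_spans_to_keep, max_antecedents=2)
# Map positions back to original span ids; entries where valid is False are padding.
antecedent_span_ids = top_span_indices[0][antecedents]
print(antecedent_span_ids.tolist())  # [[1, 1], [1, 1], [3, 1]]
print(valid.tolist())  # [[False, False], [True, False], [True, True]]
```

Read together, the last two outputs say that span 1 has no antecedent, span 3 can link back to span 1, and span 5 can link back to spans 3 and 1. Had the indices been left in score order (`[3, 1, 5]`), the same position-based window would pair span 1 with span 3, which comes later in the text.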