def _prune_and_sort_spans(mention_scores: torch.FloatTensor,
num_spans_to_keep: int) -> torch.IntTensor:
"""
The indices of the top-k scoring spans according to span_scores. We return the
indices in their original order, not ordered by score, so that we can rely on
the ordering to consider the previous k spans as antecedents for each span later.
Parameters
----------
mention_scores : ``torch.FloatTensor``, required.
The mention score for every candidate, with shape (batch_size, num_spans, 1).
num_spans_to_keep : ``int``, required.
The number of spans to keep when pruning.
Returns
-------
top_span_indices : ``torch.IntTensor``, required.
The indices of the top-k scoring spans. Has shape (batch_size, num_spans_to_keep).
"""
# Shape: (batch_size, num_spans_to_keep, 1)
_, top_span_indices = mention_scores.topk(num_spans_to_keep, 1)
top_span_indices, _ = torch.sort(top_span_indices, 1)
# Shape: (batch_size, num_spans_to_keep)
top_span_indices = top_span_indices.squeeze(-1)
return top_span_indices
评论列表
文章目录