def segment_max(inputs, segment_ids, num_segments=None, default=0.0):
    """Compute a per-segment maximum of ``inputs`` grouped by ``segment_ids``.

    The rows of ``inputs`` are reordered so that their segment ids are in
    ascending order. The reduction itself is then delegated to the
    ``SegmentMax`` autograd function, which is defined elsewhere in this module.

    Per the original author's note, this path is tuned to reduce the number of
    individual PyTorch calls. It assumes that most segments contain zero or one
    element.

    Args:
        inputs: Tensor whose rows (dim 0) are grouped into segments.
            NOTE(review): assumes ``inputs.size(0)`` equals the number of
            segment ids. Confirm with callers.
        segment_ids: Segment id for each row of ``inputs``. Presumably 1-D:
            ``torch.sort`` sorts along the last dimension, and the resulting
            permutation is applied to dim 0 of ``inputs``.
        num_segments: Forwarded unchanged to ``SegmentMax``. Presumably the
            number of output segments, with ``None`` meaning "infer".
            Confirm against ``SegmentMax``.
        default: Forwarded unchanged to ``SegmentMax``. Presumably the fill
            value for empty segments. Confirm against ``SegmentMax``.

    Returns:
        The result of ``SegmentMax.apply`` on the sorted ids and permuted rows.
    """
    # Sort the ids and apply the same permutation to the rows of ``inputs`` so
    # that rows belonging to the same segment become contiguous (presumably
    # what SegmentMax relies on).
    segment_ids, indices = torch.sort(segment_ids)
    inputs = torch.index_select(inputs, 0, indices)
    output = SegmentMax.apply(inputs, segment_ids, num_segments, default)
    return output
# 评论列表 (scraped web-page label: "Comment list")
# 文章目录 (scraped web-page label: "Table of contents")