segment.py 文件源码

python
阅读 52 收藏 0 点赞 0 评论 0

项目:jack 作者: uclmr 项目源码 文件源码
def segment_max(inputs, segment_ids, num_segments=None, default=0.0):
    # highly optimized to decrease the amount of actual invocation of pytorch calls
    # assumes that most segments have 1 or 0 elements
    segment_ids, indices = torch.sort(segment_ids)
    inputs = torch.index_select(inputs, 0, indices)
    output = SegmentMax.apply(inputs, segment_ids, num_segments, default)
    return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号