misc.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:jack 作者: uclmr 项目源码 文件源码
def mask_for_lengths(length, max_length=None, mask_right=True, value=-1e6):
    max_length = max_length or length.max().data[0]
    mask = torch.cuda.IntTensor() if length.is_cuda else torch.IntTensor()
    mask = torch.arange(0, max_length, 1, out=mask)
    mask = torch.autograd.Variable(mask).type_as(length)
    mask /= length.unsqueeze(1)
    mask = mask.clamp(0, 1)
    mask = mask.float()
    if not mask_right:
        mask = 1.0 - mask
    mask *= value
    return mask
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号