utils.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:seqmod 作者: emanjavacas 项目源码 文件源码
def make_length_mask(lengths):
    """
    Compute binary length mask.

    lengths: Variable torch.LongTensor(batch) should be on the desired
        output device.

    Returns:
    --------

    mask: torch.ByteTensor(batch x seq_len)
    """
    maxlen, batch = lengths.data.max(), len(lengths)
    mask = torch.arange(0, maxlen, out=lengths.data.new()) \
                .repeat(batch, 1) \
                .lt(lengths.data.unsqueeze(1))
    return Variable(mask, volatile=lengths.volatile)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号