def make_length_mask(lengths, max_len=None):
    """
    Compute a binary length mask from a vector of sequence lengths.

    Parameters
    ----------
    lengths : torch.LongTensor of shape (batch,)
        Length of each sequence in the batch. The mask is created on the
        same device as ``lengths``, so put it on the desired output device.
    max_len : int, optional
        Width of the mask (``seq_len``). Defaults to ``max(lengths)``. Pass
        the padded sequence length when it may exceed the longest length.

    Returns
    -------
    mask : torch.BoolTensor of shape (batch, seq_len)
        ``mask[b, t]`` is True iff ``t < lengths[b]``. An empty ``lengths``
        yields an empty ``(0, max_len or 0)`` mask.
    """
    if max_len is None:
        # Tensor.max() raises on an empty tensor, so an empty batch gets width 0.
        max_len = int(lengths.max()) if lengths.numel() > 0 else 0
    positions = torch.arange(max_len, dtype=lengths.dtype, device=lengths.device)
    # Broadcast (1, seq_len) against (batch, 1) rather than materialising a
    # repeated (batch, seq_len) index matrix. Comparison results never
    # require grad, so no Variable / volatile wrapping is needed.
    return positions.unsqueeze(0) < lengths.unsqueeze(1)
# Stray page-navigation labels from the original web source:
# "Comment list" / "Table of contents".