def sequence_mask(lens, max_len=None):
    """Build a boolean padding mask from sequence lengths.

    Args:
        lens: 1-D integer tensor of shape ``(batch,)`` holding the valid
            length of each sequence.
        max_len: Width of the mask. Defaults to ``lens.max()`` so every
            sequence fits.

    Returns:
        Bool tensor of shape ``(batch, max_len)`` where ``mask[i, j]`` is
        True iff ``j < lens[i]`` (i.e. position ``j`` is a real token, not
        padding). Created on the same device as ``lens``.
    """
    batch_size = lens.size(0)
    if max_len is None:
        # .item() extracts the Python int from the 0-dim max tensor
        # (the old `.data[0]` accessor fails on modern PyTorch).
        max_len = int(lens.max().item())
    # Position indices 0..max_len-1, created directly on lens's device —
    # replaces the deprecated Variable(...) + .cuda() copy dance.
    ranges = torch.arange(max_len, device=lens.device, dtype=torch.long)
    ranges = ranges.unsqueeze(0).expand(batch_size, max_len)
    lens_exp = lens.unsqueeze(1).expand_as(ranges)
    # True where the position index is still inside the sequence.
    return ranges < lens_exp
# 评论列表 (comment list — leftover web-page navigation text from the source the snippet was copied from)
# 文章目录 (article table of contents — same scrape residue; kept as comments so the file remains valid Python)