def make_positions(tokens, padding_idx, left_pad, offset=0):
    """Replace non-padding symbols in ``tokens`` with their position numbers.

    Position numbers begin at ``padding_idx + 1 + offset``. Entries equal to
    ``padding_idx`` are treated as padding and left unchanged.

    Args:
        tokens: 2-D tensor indexed as ``(row, column)``; it is presumably an
            integer token-id batch of shape ``(batch, seqlen)``.
        padding_idx: value that marks padding entries.
        left_pad: if True, each row's positions are shifted left by the number
            of padding entries in that row. This way, in left-padded rows, the
            first real token gets position ``padding_idx + 1 + offset``.
        offset: amount added to every position.

    Returns:
        A new tensor with the same shape, dtype and device as ``tokens``.
    """
    seqlen = tokens.size(1)
    # Build the position range per call on tokens' own device/dtype. A range
    # cached across calls goes stale when padding_idx, device or dtype change.
    # Starting at padding_idx + 1 keeps real positions distinct from padding.
    positions = torch.arange(
        padding_idx + 1 + offset,
        padding_idx + 1 + offset + seqlen,
        dtype=tokens.dtype,
        device=tokens.device,
    ).expand_as(tokens)
    mask = tokens.ne(padding_idx)
    if left_pad:
        # Shift each row left by its count of padding entries; the count is
        # int64, so cast back to keep the result in tokens' dtype.
        num_pad = seqlen - mask.sum(dim=1, keepdim=True)
        positions = (positions - num_pad).to(tokens.dtype)
    # Real tokens take their position number; padding entries are kept as-is.
    return torch.where(mask, positions, tokens)
评论列表
文章目录