import numpy as np
import torch


def get_attn_subsequent_mask(seq):
    ''' Get an attention mask to avoid attending to subsequent positions. '''
    assert seq.dim() == 2
    # One (seq_len x seq_len) mask per sequence in the batch.
    attn_shape = (seq.size(0), seq.size(1), seq.size(1))
    # Ones strictly above the diagonal mark the future positions to hide.
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    subsequent_mask = torch.from_numpy(subsequent_mask)
    if seq.is_cuda:
        subsequent_mask = subsequent_mask.cuda()
    return subsequent_mask
Source file: Models.py (Python)
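A minimal usage sketch (the toy batch below is illustrative, not from the original file): for two sequences of length 4, each returned mask has a 1 at every position a query must not attend to, so the attention scores at those positions can be filled with -inf before the softmax.

seq = torch.zeros(2, 4, dtype=torch.long)  # hypothetical batch: 2 sequences of length 4
mask = get_attn_subsequent_mask(seq)       # shape (2, 4, 4), dtype uint8
print(mask[0])
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.uint8)
# Typical application: scores.masked_fill(mask.bool(), float('-inf'))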