def forward(self, q, k, v, attn_mask=None):
    # Scaled dot-product attention: softmax(q @ k^T / temperature) @ v
    attn = torch.bmm(q, k.transpose(1, 2)) / self.temper
    if attn_mask is not None:
        assert attn_mask.size() == attn.size(), \
            'Attention mask shape {} mismatches ' \
            'attention logit tensor shape {}.'.format(
                attn_mask.size(), attn.size())
        # Set masked positions to -inf before the softmax so they get
        # zero attention weight; masking the tensor directly avoids the
        # in-place write through .data, which bypasses autograd.
        attn = attn.masked_fill(attn_mask, float('-inf'))
    attn = self.softmax(attn)
    attn = self.dropout(attn)
    # Weighted sum over the value vectors; the attention weights are
    # also returned for visualization.
    output = torch.bmm(attn, v)
    return output, attn
Source code from Modules.py
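The listing above shows only the forward method. For context, here is a minimal self-contained sketch of the module it belongs to, inferred from the attributes it uses (self.temper, self.softmax, self.dropout). The class name ScaledDotProductAttention, the constructor signature (d_model, attn_dropout), and the choice of temperature sqrt(d_model) are assumptions based on the standard scaled dot-product attention formulation, not necessarily the original file's exact code.

import numpy as np
import torch
import torch.nn as nn

class ScaledDotProductAttention(nn.Module):
    '''Scaled dot-product attention (sketch; wrapper class assumed).'''

    def __init__(self, d_model, attn_dropout=0.1):
        super().__init__()
        # Temperature sqrt(d_model) rescales the dot products so the
        # softmax does not saturate as the model dimension grows.
        self.temper = np.power(d_model, 0.5)
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, attn_mask=None):
        attn = torch.bmm(q, k.transpose(1, 2)) / self.temper
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask, float('-inf'))
        attn = self.softmax(attn)
        attn = self.dropout(attn)
        return torch.bmm(attn, v), attn

# Example usage with illustrative shapes: batch 2, sequence length 4, d_model 8.
layer = ScaledDotProductAttention(d_model=8)
q = k = v = torch.randn(2, 4, 8)
out, weights = layer(q, k, v)  # out: (2, 4, 8), weights: (2, 4, 4)

Note that applying dropout to the attention weights themselves, as the original forward does, randomly removes some query-key links during training rather than perturbing the output vectors.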