Modules.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:attention-is-all-you-need-pytorch 作者: jadore801120 项目源码 文件源码
def forward(self, q, k, v, attn_mask=None):

        attn = torch.bmm(q, k.transpose(1, 2)) / self.temper

        if attn_mask is not None:

            assert attn_mask.size() == attn.size(), \
                    'Attention mask shape {} mismatch ' \
                    'with Attention logit tensor shape ' \
                    '{}.'.format(attn_mask.size(), attn.size())

            attn.data.masked_fill_(attn_mask, -float('inf'))

        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)

        return output, attn
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号