def forward(self, q, k, v, attn_mask):
    """Scaled dot-product attention.

    Args:
        q: query tensor, shape (batch, len_q, d_k) -- assumed from the bmm
            usage; TODO confirm against the enclosing module.
        k: key tensor, shape (batch, len_k, d_k).
        v: value tensor, shape (batch, len_k, d_v).
        attn_mask: bool/byte tensor broadcastable to (batch, len_q, len_k);
            True positions are excluded from attention.

    Returns:
        Attention output of shape (batch, len_q, d_v).
    """
    # Similarity scores scaled by self.temper (presumably sqrt(d_k),
    # set by the enclosing module -- TODO confirm).
    attn = torch.bmm(q, k.transpose(1, 2)) / self.temper
    # Fix: out-of-place masked_fill instead of attn.data.masked_fill_(...).
    # Writing through .data bypasses autograd bookkeeping and can silently
    # corrupt gradients; masked_fill keeps the op on the autograd tape
    # while producing identical forward values.
    attn = attn.masked_fill(attn_mask, float('-inf'))
    # self.softmax may be a dim-less nn.Softmax(), so flatten to 2-D,
    # normalize over the key dimension, then restore the original shape.
    attn = self.softmax(attn.view(-1, attn.size(2))).view(*attn.size())
    attn = self.dropout(attn)
    # Weighted sum of values: (batch, len_q, len_k) x (batch, len_k, d_v).
    return torch.bmm(attn, v)
# NOTE(review): stray text from a scraped blog page ("评论列表" = comment list,
# "文章目录" = article table of contents). Not part of the code; commented out
# so the file parses as valid Python.