def forward(self, q, k, v, attn_mask):
    d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
    residual = q
    bsz, len_q, d_model = q.size()
    len_k, len_v = k.size(1), v.size(1)

    def reshape(x):
        """[bsz, len, d_model] -> [n_head, bsz*len, d_model]"""
        return x.repeat(n_head, 1, 1).view(n_head, -1, d_model)

    # Stack the heads into the batch dimension so a single bmm projects all heads.
    q_s, k_s, v_s = map(reshape, [q, k, v])

    # Per-head projections: [n_head*bsz, len_*, d_k or d_v]
    q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)
    k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)
    v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)

    # Scaled dot-product attention over the head-expanded batch;
    # the mask is tiled once per head: [n_head*bsz, len_q, len_k]
    outputs = self.attention(q_s, k_s, v_s, attn_mask.repeat(n_head, 1, 1))

    # Undo the head-batching: split back into n_head chunks of size bsz and
    # concatenate along the feature dim -> [bsz*len_q, n_head*d_v]
    outputs = torch.cat(torch.split(outputs, bsz, dim=0), dim=-1).view(-1, n_head * d_v)

    # Output projection + dropout (training=self.training so dropout is
    # disabled at eval time), then restore [bsz, len_q, d_model].
    outputs = F.dropout(self.w_o(outputs), p=self.dropout, training=self.training)
    outputs = outputs.view(bsz, len_q, -1)

    # Residual connection followed by layer normalization.
    return self.lm(outputs + residual)
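For context, here is a minimal sketch of the attributes this `forward` relies on. The real module's `__init__` is not shown in this section, so everything below is inferred from the `bmm` calls: `w_qs`/`w_ks`/`w_vs` must be per-head projection matrices of shape `(n_head, d_model, d_k or d_v)`, `w_o` maps `n_head*d_v` back to `d_model`, `lm` is a layer norm, and `self.attention` is the scaled dot-product attention submodule (replaced here by an inline stand-in).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Sketch only: shapes inferred from the bmm calls in forward() above."""
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head, self.d_k, self.d_v = n_head, d_k, d_v
        # One projection matrix per head, applied with a single batched matmul.
        self.w_qs = nn.Parameter(torch.empty(n_head, d_model, d_k))
        self.w_ks = nn.Parameter(torch.empty(n_head, d_model, d_k))
        self.w_vs = nn.Parameter(torch.empty(n_head, d_model, d_v))
        self.w_o = nn.Linear(n_head * d_v, d_model)  # output projection
        self.lm = nn.LayerNorm(d_model)              # post-residual layer norm
        self.dropout = dropout                       # used as F.dropout's p
        # Hypothetical stand-in for the scaled dot-product attention submodule;
        # mask entries equal to 1/True are pushed toward -inf before softmax.
        self.attention = lambda q, k, v, mask: torch.bmm(
            F.softmax(torch.bmm(q, k.transpose(1, 2)) / d_k ** 0.5
                      + mask.float() * -1e9, dim=-1), v)
        for w in (self.w_qs, self.w_ks, self.w_vs):
            nn.init.xavier_uniform_(w)
```

A quick self-attention call, assuming the `forward` above is attached to this module:

```python
mha = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)
x = torch.randn(2, 10, 512)                      # [bsz, len, d_model]
mask = torch.zeros(2, 10, 10, dtype=torch.bool)  # nothing masked
out = mha(x, x, x, mask)                         # -> [2, 10, 512]
```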