def forward(self, dec_state, context, mask=None):
    """Bahdanau-style additive attention over the encoder context.

    Computes alpha_j = softmax_j( v^T tanh(W_h h_j + W_s s) ) and the
    corresponding weighted sum of encoder states.

    :param dec_state: decoder state, (batch, dec_dim)
    :param context: encoder outputs, (batch, T, enc_dim)
    :param mask: optional (batch, T) tensor of 1s/0s; positions with 0 are
        excluded from the softmax. Assumes every row has at least one
        valid (nonzero) position — otherwise the normalization divides by 0.
    :return: Weighted context, (batch, enc_dim)
             Alpha weights (viz), (batch, T)
    """
    batch, source_l, enc_dim = context.size()
    assert enc_dim == self.enc_dim
    # W_s * s over the entire batch -> (batch, attn_dim)
    dec_contrib = self.decoder_in(dec_state)
    # W_h * h over the entire length & batch -> (batch, source_l, attn_dim)
    enc_contribs = self.encoder_in(
        context.view(-1, self.enc_dim)).view(batch, source_l, self.attn_dim)
    # tanh( W_h h_j + W_s s_{i-1} ); unsqueeze(1) broadcasts the decoder
    # term across source_l (no expand_as needed). torch.tanh replaces the
    # deprecated F.tanh.
    pre_attn = torch.tanh(enc_contribs + dec_contrib.unsqueeze(1))
    # v^T * pre_attn for all batches/lengths -> (batch, source_l)
    energy = self.att_linear(pre_attn.view(-1, self.attn_dim)).view(batch, source_l)
    if mask is not None:
        # Masked, numerically-stable softmax: shift by the per-row max
        # before exponentiating, zero out invalid positions, renormalize.
        # keepdim=True keeps the reduced dim so broadcasting works on
        # modern PyTorch (old code relied on 0.x reductions keeping it).
        shift = energy.max(1, keepdim=True)[0]
        energy_exp = (energy - shift).exp() * mask
        alpha = energy_exp / energy_exp.sum(1, keepdim=True)
    else:
        # Explicit dim=1: softmax over the source positions.
        alpha = F.softmax(energy, dim=1)
    # (batch, 1, T) @ (batch, T, enc_dim) -> (batch, enc_dim)
    weighted_context = torch.bmm(alpha.unsqueeze(1), context).squeeze(1)
    return weighted_context, alpha
# NOTE(review): the two lines below are scrape residue from the hosting web
# page ("评论列表" = comment list, "文章目录" = article table of contents),
# not part of the code; commented out so the file remains valid Python.
# 评论列表
# 文章目录