def forward(self, input_enc, input_attW_enc, input_dec, lengths_enc, hidden_att=None, hidden_dec1=None, hidden_dec2=None):
    """Run one attention-decoder step over the encoder outputs.

    Args:
        input_enc: encoder outputs, N x T_enc x 2H (inferred from the bmm below).
        input_attW_enc: encoder-side attention projection laid out like the
            packed data (sum(T_enc) x 2H) — added to the decoder query term.
        input_dec: previous decoder output frame, N x O_dec.
        lengths_enc: per-example encoder lengths, sorted descending for packing.
        hidden_att, hidden_dec1, hidden_dec2: optional GRU hidden states.

    Returns:
        Tuple (output, hidden_att, hidden_dec1, hidden_dec2); output has shape
        N x r_factor x (out_dim / r_factor).

    Side effects:
        Stores the normalized attention weights (N x T_enc x 1) in
        ``self.attn_weights``.
    """
    batch = input_dec.size(0)

    # Attention RNN: prenet the previous frame, advance one GRU step.
    query = self.prenet(input_dec).unsqueeze(1)          # N x O_dec -> N x 1 x H
    query, hidden_att = self.gru_att(query, hidden_att)  # N x 1 x 2H

    # Broadcast the decoder query across encoder time, then pack so energies
    # are only computed on valid (non-padded) positions.
    dec_proj = self.linear_dec(query.squeeze(1)).unsqueeze(1).expand_as(input_enc)
    dec_proj = rnn.pack_padded_sequence(dec_proj, lengths_enc, True)  # sum(T_enc) x 2H

    # Additive attention energies on the packed data: tanh(enc_term + dec_term),
    # scored to a single positive value per position via exp.
    energies = torch.add(input_attW_enc, dec_proj.data).tanh()  # sum(T_enc) x 2H
    energies = self.attn(energies).exp()                        # sum(T_enc) x 1

    # Re-pad (pads become exactly 0) and L1-normalize along time: this is a
    # masked softmax over the valid encoder positions.
    scores = rnn.PackedSequence(energies, dec_proj.batch_sizes)
    scores, _ = rnn.pad_packed_sequence(scores, True)
    self.attn_weights = F.normalize(scores, 1, 1)  # N x T_enc x 1

    # Context vector: attention-weighted sum of the encoder outputs.
    context = torch.bmm(self.attn_weights.transpose(1, 2), input_enc)  # N x 1 x 2H

    # Two decoder GRUs with a residual shortcut accumulated across both.
    dec_in = torch.cat((context, query), 2)                    # N x 1 x 4H
    shortcut = self.short_cut(dec_in.squeeze(1)).unsqueeze(1)  # N x 1 x 2H
    dec_out, hidden_dec1 = self.gru_dec1(dec_in, hidden_dec1)
    shortcut = shortcut + dec_out
    dec_out, hidden_dec2 = self.gru_dec2(shortcut, hidden_dec2)
    shortcut = shortcut + dec_out

    # Project to r_factor output frames per decoder step.
    output = self.out(shortcut.squeeze(1)).view(batch, self.r_factor, -1)
    return output, hidden_att, hidden_dec1, hidden_dec2
# (removed stray blog-page navigation text — "评论列表"/"文章目录", comment-list /
#  table-of-contents labels from a web scrape; they were not valid Python)