model.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:Tacotron_pytorch 作者: root20 项目源码 文件源码
def forward(self, input_enc, input_attW_enc, input_dec, lengths_enc, hidden_att=None, hidden_dec1=None, hidden_dec2=None):
        """Run one attention-decoder step over a batch.

        The step feeds ``input_dec`` through the pre-net and the attention GRU,
        scores every valid encoder position against the resulting state,
        builds a context vector as the weighted sum of ``input_enc``, and
        passes ``[context, attention state]`` through two residual GRU layers
        before the output projection.

        Shape notes below follow the file's own inline annotations
        (N = batch, T_enc = padded encoder length, H = hidden size); they are
        not enforced here.

        Args:
            input_enc: Padded, batch-first encoder outputs (annotated as
                N x T_enc x 2H). May be padded past ``max(lengths_enc)``.
            input_attW_enc: Tensor added element-wise to the *packed* data of
                the projected attention state, so it must use the same packed
                layout (sum(lengths_enc) rows). Presumably the pre-projected
                encoder outputs -- confirm against the caller.
            input_dec: Decoder input for this step (annotated as N x O_dec).
            lengths_enc: Valid encoder length per batch item, handed to
                ``pack_padded_sequence``; with that function's default
                settings the lengths must be sorted in decreasing order.
            hidden_att: Optional previous hidden state of ``self.gru_att``.
            hidden_dec1: Optional previous hidden state of ``self.gru_dec1``.
            hidden_dec2: Optional previous hidden state of ``self.gru_dec2``.

        Returns:
            Tuple ``(output, hidden_att, hidden_dec1, hidden_dec2)`` where
            ``output`` is the projection viewed as ``(N, self.r_factor, -1)``
            and the rest are the updated GRU hidden states.

        Side effects:
            Stores the normalized attention weights (N x T_enc x 1) on
            ``self.attn_weights``.
        """
        N = input_dec.size(0)

        # Attention RNN driven by the pre-net output of the decoder input.
        out_att = self.prenet(input_dec).unsqueeze(1)                                   # N x O_dec -> N x 1 x H
        out_att, hidden_att = self.gru_att(out_att, hidden_att)                         # N x 1 x 2H

        # Broadcast the projected attention state over every encoder step, then
        # pack it so only the valid (non-padded) positions are scored.
        query = self.linear_dec(out_att.squeeze(1)).unsqueeze(1).expand_as(input_enc)
        query = rnn.pack_padded_sequence(query, lengths_enc, batch_first=True)          # N*T_enc x 2H

        energies = torch.add(input_attW_enc, query.data).tanh()
        energies = self.attn(energies).exp()                                            # N*T_enc x 1
        energies = rnn.PackedSequence(energies, query.batch_sizes)

        # Re-pad to the full width of input_enc rather than to max(lengths_enc);
        # otherwise the bmm below fails whenever input_enc carries extra padding.
        weights, _ = rnn.pad_packed_sequence(
            energies, batch_first=True, total_length=input_enc.size(1)
        )
        # Padded slots are exactly 0 after re-padding, so L1-normalizing the
        # exponentiated scores over the time axis is a softmax restricted to the
        # valid encoder steps.
        weights = F.normalize(weights, p=1, dim=1)                                      # N x T_enc x 1
        self.attn_weights = weights

        # Context vector: attention-weighted sum of encoder outputs.
        attn_applied = torch.bmm(weights.transpose(1, 2), input_enc)                    # N x 1 x 2H

        out_dec = torch.cat((attn_applied, out_att), 2)                                 # N x 1 x 4H
        # Linear shortcut brings the concatenated input to the GRU width so it
        # can be added to the GRU outputs as a residual.
        residual = self.short_cut(out_dec.squeeze(1)).unsqueeze(1)                      # N x 1 x 2H

        out_dec, hidden_dec1 = self.gru_dec1(out_dec, hidden_dec1)
        residual = residual + out_dec

        out_dec, hidden_dec2 = self.gru_dec2(residual, hidden_dec2)
        residual = residual + out_dec

        output = self.out(residual.squeeze(1)).view(N, self.r_factor, -1)
        return output, hidden_att, hidden_dec1, hidden_dec2
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号