model.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:Tacotron_pytorch 作者: root20 项目源码 文件源码
def forward(self, input, lengths):
        """Run the CBHG stack over a padded batch and return GRU features.

        Pipeline, as implemented below:
        1. A bank of ``self.num_filters`` 1-D convolutions, concatenated on
           the channel axis.
        2. Batch-norm + ReLU, then a width-2 / stride-1 max-pool.
        3. Two projection convolutions with batch-norm.
        4. A residual connection back to ``input``, through
           ``self.residual_proj`` when the shapes differ.
        5. ``self.highway``.
        6. ``self.BGRU`` over a packed sequence.

        Args:
            input: Tensor laid out as N x T x H (batch, time, features).
            lengths: Valid length of each sequence, forwarded to
                ``pack_padded_sequence``. NOTE(review): with that function's
                default ``enforce_sorted=True`` the lengths must be in
                descending order, and a tensor must live on the CPU —
                confirm that callers sort the batch.

        Returns:
            Tensor of shape N x T x D, where D is the GRU output size.
            D is presumably 2H because ``BGRU`` is bidirectional — TODO
            confirm. Steps past each sequence's length are zero-filled.
            The time axis is always T, even when ``max(lengths) < T``.
        """
        N, T = input.size(0), input.size(1)

        input_t = input.transpose(1, 2)  # NxTxH -> NxHxT
        # Input with one extra zero frame appended on the time axis. It is
        # loop-invariant, so build it once rather than once per filter.
        input_t_padded = F.pad(input_t, (0, 1))  # NxHx(T+1)

        conv_bank_out = []
        for i in range(self.num_filters):
            # torch.cat below needs every bank output to have the same time
            # length. Even-indexed filters get the extra frame to equalize
            # them; presumably the kernel widths alternate odd/even — TODO
            # confirm against __init__.
            tmp_input = input_t_padded if i % 2 == 0 else input_t
            conv_bank_out.append(self.conv_bank[i](tmp_input))

        # Stack the bank outputs along the channel axis.
        residual = torch.cat(conv_bank_out, dim=1)
        residual = F.relu(self.bn_list[0](residual))
        # Window 2, stride 1: shortens the time axis by exactly one step,
        # which undoes the extra frame added above.
        residual = F.max_pool1d(residual, 2, stride=1)
        residual = self.conv1(residual)
        residual = F.relu(self.bn_list[1](residual))
        residual = self.conv2(residual)
        # Back to batch-major / time-major layout: NxCxT -> NxTxC.
        residual = self.bn_list[2](residual).transpose(1, 2)

        rnn_input = input
        if rnn_input.size() != residual.size():
            # Project the input so the residual sum is well-defined.
            rnn_input = self.residual_proj(rnn_input)
        rnn_input = rnn_input + residual
        # reshape (not view) also works if the highway output is non-contiguous.
        rnn_input = self.highway(rnn_input).reshape(N, T, -1)

        packed = rnn.pack_padded_sequence(rnn_input, lengths, batch_first=True)
        output, _ = self.BGRU(packed)  # no h_0 passed, so PyTorch uses zeros
        # total_length=T keeps the time axis at T. Without it the output is
        # only max(lengths) steps long and no longer lines up with the input.
        output, _ = rnn.pad_packed_sequence(
            output, batch_first=True, total_length=T
        )
        return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号