loader.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:Tacotron_pytorch 作者: root20 项目源码 文件源码
def mask_prev_h(self, prev_h):
        if self.len_wave_mask is not None:
            if self.use_gpu:
                self.len_wave_mask = self.len_wave_mask.cuda()

            h_att, h_dec1, h_dec2 = prev_h
            h_att = torch.index_select(h_att.data, 1, self.len_wave_mask)  # batch idx is
            h_dec1 = torch.index_select(h_dec1.data, 1, self.len_wave_mask)
            h_dec2 = torch.index_select(h_dec2.data, 1, self.len_wave_mask)
            prev_h = (Variable(h_att), Variable(h_dec1), Variable(h_dec2))
        else:
            prev_h = prev_h

        return prev_h
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号