cls_sparse_skip_filt.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:mss_pytorch 作者: Js-Mim 项目源码 文件源码
def forward(self, input_x):
        """Encode a batch of sub-band frames with two residual GRU cells.

        A forward-in-time cell (``self.gruEncF``) and a backward-in-time cell
        (``self.gruEncB``) are stepped over all ``self._T`` frames. After each
        step, the frame fed to a cell is added back to that cell's new hidden
        state (residual connection). Iterations ``t < self._L`` and
        ``t >= self._T - self._L`` are discarded as context. For every other
        iteration, the two hidden states are concatenated along the feature
        axis.

        Args:
            input_x: numpy array shaped (B, T, N): batches, time frames,
                frequency sub-bands. Only the first ``self._F`` sub-bands are
                used, and they are raised to the power ``self._alpha``.
                NOTE(review): presumably float32 so it matches the GRU
                weights; confirm with callers.

        Returns:
            Tensor shaped (self._B, self._T - 2 * self._L, 2 * self._F). Each
            kept row is [forward-cell state, backward-cell state].
        """
        # Use the GPU only when one is actually usable. ``torch.has_cudnn``
        # (used previously) only reports whether PyTorch was built with cuDNN,
        # so on a CUDA build without a visible GPU the ``.cuda()`` calls failed.
        use_cuda = torch.cuda.is_available()

        def to_device(tensor):
            return tensor.cuda() if use_cuda else tensor

        batch, steps, feats, ctx = self._B, self._T, self._F, self._L

        # Initialization of the hidden states (constants, no gradient).
        h_t_fr = Variable(to_device(torch.zeros(batch, feats)), requires_grad=False)
        h_t_bk = Variable(to_device(torch.zeros(batch, feats)), requires_grad=False)
        # Output buffer: one row per frame outside the L-frame context on each side.
        H_enc = Variable(to_device(torch.zeros(batch, steps - (2 * ctx), 2 * feats)),
                         requires_grad=False)

        # Input is of the shape: (B (batches), T (time-sequence), N (frequency sub-bands)).
        # Keep only the first F sub-bands and raise them to the power alpha.
        cxin = Variable(torch.pow(to_device(torch.from_numpy(input_x[:, :, :feats])),
                                  self._alpha))

        for t in range(steps):
            x_fr = cxin[:, t, :]
            # NOTE(review): at iteration t the backward cell consumes frame
            # T - t - 1, so the pair concatenated below is (state after frame t,
            # state after frame T - t - 1). The halves are only time-aligned near
            # the middle of the sequence. Behavior kept unchanged; confirm intent.
            x_bk = cxin[:, steps - t - 1, :]

            # Bi-GRU encoding followed by residual connections. The additions
            # are out-of-place so the tensors returned by the GRU cells (which
            # autograd may have recorded) are never mutated in place.
            h_t_fr = self.gruEncF(x_fr, h_t_fr) + x_fr
            h_t_bk = self.gruEncB(x_bk, h_t_bk) + x_bk

            # Remove context and concatenate.
            if ctx <= t < steps - ctx:
                H_enc[:, t - ctx, :] = torch.cat((h_t_fr, h_t_bk), dim=1)

        return H_enc
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号