cls_sparse_skip_filt.py source file

python

Project: mss_pytorch  Author: Js-Mim
def forward(self, H_enc):
    # Number of decoded time steps: T frames minus L frames on each side
    steps = self._T - (self._L * 2)

    # Initialization of the hidden states, allocated on the same device and
    # dtype as H_enc so the same code path runs on CPU and GPU
    h_t_dec = H_enc.new_zeros(self._B, self._gruout)

    # Initialization of the decoder output, shape (B, T - 2L, gruout)
    H_j_dec = H_enc.new_zeros(self._B, steps, self._gruout)

    for ts in range(steps):
        # GRU decoding: one recurrent step per time frame
        h_t_dec = self.gruDec(H_enc[:, ts, :], h_t_dec)
        H_j_dec[:, ts, :] = h_t_dec

    return H_j_dec
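
The excerpt above is a single method and cannot run on its own. The following is a minimal, self-contained sketch of how such a decoder loop might be exercised. The class name `ToyDecoder`, the constructor arguments, the tensor sizes, and the choice of `nn.GRUCell` for `gruDec` are all assumptions made for illustration; they are not taken from the mss_pytorch source.

```python
import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for the decoder; names and sizes are illustrative."""

    def __init__(self, batch, frames, context, features):
        super().__init__()
        self._B = batch          # batch size
        self._T = frames         # number of input time frames
        self._L = context        # assumed: frames trimmed from each side
        self._gruout = features  # size of the GRU hidden state
        # Assumption: the recurrent step is a GRUCell with equal input and hidden size
        self.gruDec = nn.GRUCell(features, features)

    def forward(self, H_enc):
        steps = self._T - (self._L * 2)
        # Zero-initialized hidden state and output buffer on H_enc's device
        h_t_dec = H_enc.new_zeros(self._B, self._gruout)
        H_j_dec = H_enc.new_zeros(self._B, steps, self._gruout)
        for ts in range(steps):
            # One GRU step per time frame; store the new hidden state
            h_t_dec = self.gruDec(H_enc[:, ts, :], h_t_dec)
            H_j_dec[:, ts, :] = h_t_dec
        return H_j_dec


torch.manual_seed(0)
decoder = ToyDecoder(batch=4, frames=10, context=2, features=8)
H_enc = torch.randn(4, 6, 8)  # encoder output: (B, T - 2L, features)
H_j_dec = decoder(H_enc)
print(H_j_dec.shape)  # torch.Size([4, 6, 8])
```

The output keeps one hidden state per time step, so its shape matches the encoder output whenever the GRU input and hidden sizes are equal, as assumed here.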