def forward(self, input_x):
    """Bi-directional GRU encoding of a magnitude-spectrogram batch.

    Args:
        input_x: numpy array of shape (B, T, N) — B batches, T time frames,
            N frequency sub-bands (N >= self._F; the extra bands are cropped).

    Returns:
        ``H_enc``: Variable of shape (B, T - 2*L, 2*F) holding the
        concatenated forward/backward hidden states, with L context frames
        removed from each end of the sequence.
    """
    # BUG FIX: `torch.has_cudnn` only says the binary was *compiled* with
    # cuDNN — it can be True on a machine with no GPU, in which case the
    # `.cuda()` calls below would crash. Check for an actual device instead.
    use_cuda = torch.cuda.is_available()

    def _dev(t):
        # Move a tensor to the GPU when one is available (no-op otherwise).
        return t.cuda() if use_cuda else t

    # Initialization of the hidden states (forward and backward directions).
    h_t_fr = Variable(_dev(torch.zeros(self._B, self._F)), requires_grad=False)
    h_t_bk = Variable(_dev(torch.zeros(self._B, self._F)), requires_grad=False)
    H_enc = Variable(_dev(torch.zeros(self._B, self._T - (2 * self._L), 2 * self._F)),
                     requires_grad=False)
    # Input is of the shape: (B (batches), T (time-sequence), N (frequency sub-bands)).
    # Crop the "un-necessary" frequency sub-bands and apply the magnitude
    # compression exponent alpha.
    cxin = Variable(torch.pow(_dev(torch.from_numpy(input_x[:, :, :self._F])), self._alpha))

    for t in range(self._T):
        # Bi-GRU encoding: forward cell walks t = 0..T-1, backward cell
        # walks the sequence in reverse.
        h_t_fr = self.gruEncF(cxin[:, t, :], h_t_fr)
        h_t_bk = self.gruEncB(cxin[:, self._T - t - 1, :], h_t_bk)
        # Residual connections. Out-of-place addition: the original `+=`
        # mutated the GRU cell's output in place, which can invalidate
        # tensors autograd saved for the backward pass.
        h_t_fr = h_t_fr + cxin[:, t, :]
        h_t_bk = h_t_bk + cxin[:, self._T - t - 1, :]
        # Remove the L context frames at each end and concatenate directions.
        if self._L <= t < self._T - self._L:
            H_enc[:, t - self._L, :] = torch.cat((h_t_fr, h_t_bk), dim=1)
    return H_enc
# NOTE(review): the two lines below are web-page navigation residue from a
# copy-paste ("评论列表" = comment list, "文章目录" = table of contents);
# commented out so the module stays importable.
# 评论列表
# 文章目录