def forward(self, H_enc):
    """Run the GRU decoder step by step over the encoder output.

    Args:
        H_enc: Encoder output, indexed as ``H_enc[:, ts, :]``. It is therefore
            a 3-D tensor with the time axis at dim 1 (assumed layout
            ``(batch, time, features)`` — confirm against the encoder).
            It must provide at least ``self._T - 2 * self._L`` time steps.

    Returns:
        Tensor of shape ``(batch, self._T - 2 * self._L, h)``, where ``h`` is
        the width of the state returned by ``self.gruDec``. The initial state
        is ``self._gruout`` wide.
        Entry ``[:, ts, :]`` is the decoder hidden state after step ``ts``.
    """
    n_steps = self._T - (self._L * 2)
    # Take the batch size from the input rather than the fixed ``self._B``,
    # so a smaller final batch still decodes correctly.
    batch = H_enc.size(0)

    if n_steps <= 0:
        # Nothing to decode: return an empty time axis.
        # A negative length still raises, as the original torch.zeros call did.
        return H_enc.new_zeros((batch, n_steps, self._gruout))

    # Zero initial hidden state, created on H_enc's device and dtype.
    # This replaces the torch.has_cudnn check, which only says whether cuDNN
    # was compiled in, not where the data lives. It also drops the
    # deprecated Variable wrapper.
    h_t_dec = H_enc.new_zeros((batch, self._gruout))

    outputs = []
    for ts in range(n_steps):
        # One GRU decoding step. self.gruDec is presumably an
        # nn.GRUCell-like module: (input, hidden) -> new hidden.
        h_t_dec = self.gruDec(H_enc[:, ts, :], h_t_dec)
        outputs.append(h_t_dec)
    # Stack along the time axis instead of writing in place into a
    # preallocated buffer.
    return torch.stack(outputs, dim=1)
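# 评论列表 ("Comment list") — web-page navigation residue, not code
# 文章目录 ("Table of contents") — web-page navigation residue, not code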