def forward(self, input_, hx):
    """Compute the next hidden state of the cell for a single time step.

    Two gates ``r`` and ``z`` are read from one fused affine map of the
    input and the previous hidden state. ``r`` scales the output of
    ``self._EUNN`` applied to ``hx``; ``z`` interpolates between the
    previous hidden state and the new candidate state.

    Args:
        input_: A (batch, input_size) tensor containing input
            features.
        hx: Initial hidden state, of size (batch, hidden_size).

    Returns:
        newh: Tensor containing the next hidden state, computed as
            ``hx * z + (1 - z) * candidate``.
    """
    batch_size = hx.size(0)
    # Expand the gate bias to one row per batch element so it can be used
    # as the additive term of addmm.
    bias_batch = (self.gate_bias.unsqueeze(0)
                  .expand(batch_size, *self.gate_bias.size()))
    gate_Wh = torch.addmm(bias_batch, hx, self.gate_W)
    gate_Ux = torch.mm(input_, self.gate_U)

    # The chunk size is passed positionally: the keyword was renamed from
    # ``split_size`` to ``split_size_or_sections``, so the old keyword
    # raises TypeError on current PyTorch while the positional form works
    # on both. Unpacking into exactly two gates assumes the fused
    # pre-activation has 2 * hidden_size columns.
    r, z = torch.split(gate_Ux + gate_Wh, self.hidden_size, dim=1)
    # NOTE(review): r and z are used as raw affine outputs; no squashing
    # nonlinearity (e.g. sigmoid) is applied in this block -- confirm that
    # this is intended.

    Ux = torch.mm(input_, self.U)
    unitary = self._EUNN(hx=hx, thetaA=self.thetaA, thetaB=self.thetaB)
    unitary = unitary * r

    # Candidate state: input projection plus gated recurrent term, passed
    # through the cell's modReLU helper with its own bias.
    newh = Ux + unitary
    newh = self._modReLU(newh, self.bias)

    # z weights the previous state, (1 - z) weights the candidate.
    newh = hx * z + (1 - z) * newh
    return newh
# Web-page navigation residue ("comment list", "article table of contents"); not part of the code.