def forward(self, input_, hx):
    """Run one step of the cell and return the next hidden state.

    Computes ``self._modReLU(self._EUNN(input_ @ self.U + hx), self.bias)``.
    The input is projected by ``self.U`` and added to the previous hidden
    state. The sum is passed through ``self._EUNN`` together with the
    parameters ``self.thetaA`` / ``self.thetaB``. The result goes through
    ``self._modReLU`` with ``self.bias``.

    Args:
        input_: A (batch, input_size) tensor containing input features.
        hx: Previous hidden state, of size (batch, hidden_size).

    Returns:
        newh: Tensor containing the next hidden state, as returned by
        ``self._modReLU`` (presumably (batch, hidden_size) — confirm against
        ``_EUNN`` / ``_modReLU``).
    """
    # Project the input into the hidden space. Presumably self.U is
    # (input_size, hidden_size) so the sum below lines up with hx — TODO confirm.
    projected = torch.mm(input_, self.U)
    # NOTE(review): the input projection is added *before* _EUNN, so the
    # transform is applied to (U x + h) rather than to h alone. If _EUNN is
    # meant to be only the hidden-to-hidden transform, the input term would
    # normally be added after it — confirm this ordering is intended.
    pre_activation = projected + hx
    transformed = self._EUNN(
        hx=pre_activation, thetaA=self.thetaA, thetaB=self.thetaB
    )
    return self._modReLU(transformed, self.bias)
# Web-page residue from the source article (not code):
# "评论列表" (comment list), "文章目录" (table of contents).