def forward(self, input_, hx):
    """Compute one LSTM time step.

    Args:
        input_: A (batch, input_size) tensor containing input
            features.
        hx: A tuple (h_0, c_0), which contains the initial hidden
            and cell state, where the size of both states is
            (batch, hidden_size).

    Returns:
        h_1, c_1: Tensors containing the next hidden and cell state.

    Note:
        The pre-activations are cut along dim 1 into chunks of
        ``self.hidden_size`` columns and unpacked, in order, as the
        forget, input, output and candidate (g) gates. The four-way
        unpacking therefore requires ``self.weight_ih``,
        ``self.weight_hh`` and ``self.bias`` to produce exactly
        ``4 * hidden_size`` columns.
    """
    h_0, c_0 = hx
    batch_size = h_0.size(0)

    # Replicate the bias across the batch as a view (no copy) so it can be
    # used as the additive term of addmm.
    bias_batch = self.bias.unsqueeze(0).expand(batch_size, *self.bias.size())

    # Hidden-to-hidden projection with the bias folded in, then the
    # input-to-hidden projection.
    wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
    wi = torch.mm(input_, self.weight_ih)

    # The chunk size is passed positionally: the keyword was renamed from
    # `split_size` to `split_size_or_sections`, so the positional form is
    # the only spelling accepted across PyTorch versions.
    f, i, o, g = torch.split(wh_b + wi, self.hidden_size, dim=1)

    c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
    h_1 = torch.sigmoid(o) * torch.tanh(c_1)
    return h_1, c_1
评论列表
文章目录