def forward(self, input_, hx, time):
    """Run one step of the batch-normalized LSTM cell.

    The input-to-hidden and hidden-to-hidden projections are each passed
    through their own batch-norm module (``self.bn_ih`` / ``self.bn_hh``)
    before being summed with the bias. The resulting pre-activations are
    split into the forget, input, output and candidate gates, in that order.

    Args:
        input_: A (batch, input_size) tensor containing input features.
        hx: A tuple (h_0, c_0) with the initial hidden and cell state;
            both are (batch, hidden_size).
        time: The current timestep value. It is forwarded to every
            batch-norm call so they can select the appropriate running
            statistics.

    Returns:
        h_1, c_1: Tensors containing the next hidden and cell state.
            ``c_1`` is returned un-normalized. Only the copy used to
            compute ``h_1`` goes through ``self.bn_c``.
    """
    h_0, c_0 = hx
    batch_size = h_0.size(0)
    # Explicitly tile the bias over the batch dimension.
    bias_batch = (self.bias.unsqueeze(0)
                  .expand(batch_size, *self.bias.size()))
    wh = torch.mm(h_0, self.weight_hh)
    wi = torch.mm(input_, self.weight_ih)
    bn_wh = self.bn_hh(wh, time=time)
    bn_wi = self.bn_ih(wi, time=time)
    gates = bn_wh + bn_wi + bias_batch
    # Pass the chunk size positionally. Current PyTorch renamed the
    # keyword `split_size` to `split_size_or_sections`, so the old keyword
    # raises TypeError. The positional form works across versions.
    # Unpacking into four names requires the gate width to be exactly
    # 4 * hidden_size.
    f, i, o, g = torch.split(gates, self.hidden_size, dim=1)
    c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
    h_1 = torch.sigmoid(o) * torch.tanh(self.bn_c(c_1, time=time))
    return h_1, c_1
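# Stray web-page navigation text captured with this snippet (not code):
# "Comment list", "Table of contents".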