def forward(self, xt, state):
    """Advance the recurrent cell by one time step.

    A single fused pre-activation ``self.i2h(xt) + self.h2h(h_prev)`` is read
    as five column blocks of width ``self.rnn_size`` (along dim 1):

    * blocks 0-2 go through a sigmoid and act as the input, forget and
      output gates;
    * blocks 3 and 4 are combined with an element-wise max (maxout) to form
      the candidate cell input.

    Args:
        xt: Input for this step. It is sliced along dim 1, so presumably
            ``(batch, input_size)`` -- TODO confirm against the caller.
        state: Pair ``(h, c)``. Only ``h[-1]`` and ``c[-1]`` (last entry
            along the first dim) are read, so presumably each is stacked as
            ``(num_layers, batch, rnn_size)`` -- confirm.

    Returns:
        ``(output, new_state)``. ``output`` is the new hidden state after
        ``self.dropout``. ``new_state`` is ``(h, c)``, each with a leading
        dimension of size 1 added via ``unsqueeze(0)``.
    """
    size = self.rnn_size
    h_prev = state[0][-1]
    c_prev = state[1][-1]

    all_input_sums = self.i2h(xt) + self.h2h(h_prev)

    # torch.sigmoid / torch.tanh replace the deprecated F.sigmoid / F.tanh;
    # they compute the same values.
    gates = torch.sigmoid(all_input_sums.narrow(1, 0, 3 * size))
    in_gate = gates.narrow(1, 0, size)
    forget_gate = gates.narrow(1, size, size)
    out_gate = gates.narrow(1, 2 * size, size)

    # Maxout candidate: element-wise max of the 4th and 5th column blocks.
    in_transform = torch.max(
        all_input_sums.narrow(1, 3 * size, size),
        all_input_sums.narrow(1, 4 * size, size),
    )

    next_c = forget_gate * c_prev + in_gate * in_transform
    next_h = out_gate * torch.tanh(next_c)

    # NOTE(review): self.dropout is applied *before* next_h is stored in the
    # returned state, so whatever it does also affects the recurrent hidden
    # state fed to the next step (not just the output). Preserved as-is;
    # confirm this is intended.
    next_h = self.dropout(next_h)

    output = next_h
    new_state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
    return output, new_state
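# 评论列表 (Comment list) -- navigation text left over from the scraped web page
# 文章目录 (Table of contents) -- navigation text left over from the scraped web page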