def forward(self, input, hidden=None):
    """Advance an LSTM cell by one time step.

    Computes::

        gates = input @ w_ih.T + b_ih + hx @ w_hh.T + b_hh
        i, f, g, o = split(gates, 4)      # along the last (feature) dimension
        i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
        g = tanh(g)
        cy = f * cx + i * g
        hy = o * tanh(cy)

    Args:
        input: Input features whose last dimension is the input size. Any
            leading dimensions (e.g. batch) are carried through, so an
            unbatched 1-D input is accepted as well.
        hidden: ``(hx, cx)`` pair with the previous hidden and cell state.
            If ``None``, both start as zeros of shape
            ``(*input.shape[:-1], self.w_hh.size(1))``.

    Returns:
        Tuple ``(hy, cy)``: the new hidden state and the new cell state.
    """
    if hidden is None:
        # F.linear(hx, w_hh) requires hx's feature size to equal w_hh.size(1),
        # i.e. the hidden size; zero states mirror torch.nn.LSTMCell's default.
        zeros = input.new_zeros(*input.shape[:-1], self.w_hh.size(1))
        hidden = (zeros, zeros)
    hx, cx = hidden

    # Gate pre-activations: [..., 4*hidden_size]
    gates = F.linear(input, self.w_ih, self.b_ih) + F.linear(hx, self.w_hh, self.b_hh)
    # Split on the last dimension so batched and unbatched inputs both work
    # (identical to dim=1 for the usual [bsz, 4*hidden_size] case).
    in_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, -1)

    # Tensor methods instead of F.sigmoid / F.tanh: those functional wrappers
    # just forward to these, and older PyTorch releases warned they were deprecated.
    in_gate = in_gate.sigmoid()
    forget_gate = forget_gate.sigmoid()
    out_gate = out_gate.sigmoid()
    cell_gate = cell_gate.tanh()

    cy = forget_gate * cx + in_gate * cell_gate
    hy = out_gate * cy.tanh()
    return hy, cy
评论列表
文章目录