LSTM.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:torch_light 作者: ne7ermore 项目源码 文件源码
def forward(self, input, hidden):
        hx, cx = hidden
        gates = F.linear(input, self.w_ih, self.b_ih) + F.linear(hx, self.w_hh, self.b_hh) # [bsz, 4*hidden_size]
        in_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)
        in_gate, forget_gate, out_gate = map(F.sigmoid, [in_gate, forget_gate, out_gate])
        cell_gate = F.tanh(cell_gate)

        cy = forget_gate*cx + in_gate*cell_gate
        hy = out_gate*F.tanh(cy)

        return hy, cy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号