lstm.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:chainer-deconv 作者: germanRos 项目源码 文件源码
def forward(self, inputs):
        c_prev, x = inputs
        a, i, f, o = _extract_gates(x)

        if isinstance(x, numpy.ndarray):
            self.a = numpy.tanh(a)
            self.i = _sigmoid(i)
            self.f = _sigmoid(f)
            self.o = _sigmoid(o)

            self.c = self.a * self.i + self.f * c_prev
            h = self.o * numpy.tanh(self.c)
        else:
            self.c, h = cuda.elementwise(
                'T c_prev, T a, T i_, T f, T o', 'T c, T h',
                '''
                    COMMON_ROUTINE;
                    c = aa * ai + af * c_prev;
                    h = ao * tanh(c);
                ''',
                'lstm_fwd', preamble=_preamble)(c_prev, a, i, f, o)

        return self.c, h
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号