test_lstm.py source file

python

Project: chainer-deconv  Author: germanRos
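    def check_forward(self, c_prev_data, x_data):
        # Wrap the raw arrays as Chainer variables and run the LSTM function.
        c_prev = chainer.Variable(c_prev_data)
        x = chainer.Variable(x_data)
        c, h = functions.lstm(c_prev, x)

        # Both outputs must keep the dtype under test.
        self.assertEqual(c.data.dtype, self.dtype)
        self.assertEqual(h.data.dtype, self.dtype)

        # Compute the expected output with NumPy from the test fixtures
        # self.x and self.c_prev. Each index pair selects one gate's
        # pre-activations for the two units.
        a_in = self.x[:, [0, 4]]  # cell input
        i_in = self.x[:, [1, 5]]  # input gate
        f_in = self.x[:, [2, 6]]  # forget gate
        o_in = self.x[:, [3, 7]]  # output gate

        c_expect = _sigmoid(i_in) * numpy.tanh(a_in) + \
            _sigmoid(f_in) * self.c_prev
        h_expect = _sigmoid(o_in) * numpy.tanh(c_expect)

        # Compare against the values computed by functions.lstm.
        gradient_check.assert_allclose(
            c_expect, c.data, **self.check_forward_options)
        gradient_check.assert_allclose(
            h_expect, h.data, **self.check_forward_options)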