def check_forward(self, c_prev_data, x_data):
    """Check the forward output of ``functions.lstm`` against NumPy.

    Wraps both arguments in ``chainer.Variable``, runs ``functions.lstm``,
    asserts that both outputs have dtype ``self.dtype``, and compares them
    with reference values computed via ``_sigmoid`` and ``numpy.tanh``.

    Args:
        c_prev_data: Array passed to ``functions.lstm`` as its first
            argument (the previous cell state).
        x_data: Array passed to ``functions.lstm`` as its second argument
            (the stacked gate inputs).

    Raises:
        AssertionError: If an output dtype differs from ``self.dtype``.
            NOTE(review): ``gradient_check.assert_allclose`` presumably
            raises as well on a tolerance mismatch -- confirm.
    """
    c_prev = chainer.Variable(c_prev_data)
    x = chainer.Variable(x_data)
    c, h = functions.lstm(c_prev, x)
    self.assertEqual(c.data.dtype, self.dtype)
    self.assertEqual(h.data.dtype, self.dtype)

    # Compute expected out.
    # The reference is built from the fixtures ``self.x`` / ``self.c_prev``
    # rather than from the arguments.
    # NOTE(review): presumably the arguments are those same fixtures
    # (possibly copied to another device) -- confirm against the callers.
    #
    # Column pairs select each gate's input: (0, 4) -> a, (1, 5) -> i,
    # (2, 6) -> f, (3, 7) -> o, so ``self.x`` needs at least 8 columns.
    a_in = self.x[:, [0, 4]]
    i_in = self.x[:, [1, 5]]
    f_in = self.x[:, [2, 6]]
    o_in = self.x[:, [3, 7]]
    c_expect = (_sigmoid(i_in) * numpy.tanh(a_in)
                + _sigmoid(f_in) * self.c_prev)
    h_expect = _sigmoid(o_in) * numpy.tanh(c_expect)

    gradient_check.assert_allclose(
        c_expect, c.data, **self.check_forward_options)
    gradient_check.assert_allclose(
        h_expect, h.data, **self.check_forward_options)
# 评论列表 (web-page residue: "Comment list")
# 文章目录 (web-page residue: "Article table of contents")