def test_without_residuals(self):
    """Check ExtendedMultiRNNCell against tf.contrib.rnn.MultiRNNCell.

    Both wrappers stack two 2-unit GRU cells and are fed the same random
    input and the same per-layer initial state. ExtendedMultiRNNCell is
    built from only its list of cells (per the test name, no residual
    connections), so its output and every layer's new state are expected
    to match the standard MultiRNNCell numerically.
    """
    num_layers = 2

    def make_gru_layers():
        # Fresh GRU cells for one stack; each wrapper gets its own list.
        return [tf.contrib.rnn.GRUCell(2) for _ in range(num_layers)]

    # Random draws happen in this order: the input, then one state per layer.
    inputs = tf.constant(np.random.randn(1, 2))
    state = tuple(
        tf.constant(np.random.randn(1, 2)) for _ in range(num_layers))

    # The scope-wide constant initializer is there so the two independently
    # built stacks do not receive different random initial weights.
    with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
        reference_cell = tf.contrib.rnn.MultiRNNCell(
            make_gru_layers(), state_is_tuple=True)
        reference_out = reference_cell(inputs, state, scope="standard")

        extended_cell = rnn_cell.ExtendedMultiRNNCell(make_gru_layers())
        extended_out = extended_cell(inputs, state, scope="test")

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        reference_vals, extended_vals = sess.run([reference_out, extended_out])

    # Element 0 of each result is the cell output; element 1 holds the new
    # state, which is a tuple with one entry per layer (state_is_tuple=True).
    self.assertAllClose(reference_vals[0], extended_vals[0])
    for layer in range(num_layers):
        self.assertAllClose(reference_vals[1][layer], extended_vals[1][layer])
评论列表
文章目录