def testInitialStateTuple(self, trainable, use_custom_initial_value,
                          state_size):
    """Checks `RNNCore.initial_state` for a (possibly nested) `state_size`.

    The requested `state_size` is set on the `snt.RNNCore` class for the
    duration of the test and the class's previous entry is restored
    afterwards, so the patch does not leak into other tests.

    Args:
      trainable: Passed through to `core.initial_state`. When False, every
        leaf of the initial state is expected to be all zeros.
      use_custom_initial_value: If True, every leaf gets
        `tf.constant_initializer(2)` as its trainable initializer and, when
        `trainable` is also True, is expected to be filled with 2.
      state_size: Structure of sizes, flattened with `nest.flatten`; each
        leaf `size` is expected to produce a `[batch_size, size]` tensor.
        Presumably supplied by a parameterized-test decorator (not visible
        here).
    """
    batch_size = 6
    # Set the attribute on the class, since we can't set properties of
    # abstract classes. Remember the class's own entry (if any) so it can
    # be restored in `finally` instead of leaking into other tests.
    missing = object()
    saved_state_size = vars(snt.RNNCore).get("state_size", missing)
    snt.RNNCore.state_size = state_size
    try:
        flat_state_size = nest.flatten(state_size)
        core = snt.RNNCore(name="dummy_core")
        if use_custom_initial_value:
            # One constant initializer per leaf, packed into the same nested
            # structure as `state_size`.
            flat_initializer = (
                [tf.constant_initializer(2)] * len(flat_state_size))
            trainable_initializers = nest.pack_sequence_as(
                structure=state_size, flat_sequence=flat_initializer)
        else:
            trainable_initializers = None
        initial_state = core.initial_state(
            batch_size, dtype=tf.float32, trainable=trainable,
            trainable_initializers=trainable_initializers)

        # Static checks: structure mirrors `state_size` and every leaf has
        # shape [batch_size, size].
        nest.assert_same_structure(initial_state, state_size)
        flat_initial_state = nest.flatten(initial_state)
        for state, size in zip(flat_initial_state, flat_state_size):
            self.assertEqual(state.get_shape(), [batch_size, size])

        with self.test_session() as sess:
            tf.global_variables_initializer().run()
            flat_initial_state_value = sess.run(flat_initial_state)
            for value, size in zip(flat_initial_state_value, flat_state_size):
                if not trainable:
                    expected_initial_state = np.zeros([batch_size, size])
                elif use_custom_initial_value:
                    expected_initial_state = np.full([batch_size, size], 2.0)
                else:
                    # The test does not pin the values here; it only checks
                    # that every batch row equals the first row.
                    expected_initial_state = np.tile(value[0], (batch_size, 1))
                self.assertAllClose(value, expected_initial_state)
    finally:
        if saved_state_size is missing:
            del snt.RNNCore.state_size
        else:
            snt.RNNCore.state_size = saved_state_size
# Comment list
# Table of contents