rnn_core_test.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testInitialStateTuple(self, trainable, use_custom_initial_value,
                          state_size):
    """Checks `RNNCore.initial_state` for a (possibly nested) `state_size`.

    The returned state must have the same nested structure as `state_size`,
    and every leaf must have static shape `[batch_size, size]`. The evaluated
    values are then compared against:
      * all zeros when the state is not trainable;
      * all twos when it is trainable and built from the custom
        `tf.constant_initializer(2)` initializers;
      * the first batch row repeated `batch_size` times otherwise, i.e. every
        row of the batch must be identical.

    Args:
      trainable: Forwarded to `initial_state` as its `trainable` argument.
      use_custom_initial_value: If true, constant initializers with value 2
        are passed as `trainable_initializers`; otherwise `None` is passed.
      state_size: Nested structure of sizes, assigned to
        `snt.RNNCore.state_size` for the duration of the test.
    """
    batch_size = 6

    # `state_size` is assigned on the class itself rather than on an instance,
    # because properties of the abstract class cannot be set per instance.
    snt.RNNCore.state_size = state_size
    leaf_sizes = nest.flatten(state_size)
    core = snt.RNNCore(name="dummy_core")

    # One shared constant initializer per leaf, packed back into the same
    # nested structure as `state_size`; None when custom values are not used.
    trainable_initializers = (
        nest.pack_sequence_as(
            structure=state_size,
            flat_sequence=[tf.constant_initializer(2)] * len(leaf_sizes))
        if use_custom_initial_value else None)

    initial_state = core.initial_state(
        batch_size,
        dtype=tf.float32,
        trainable=trainable,
        trainable_initializers=trainable_initializers)

    nest.assert_same_structure(initial_state, state_size)
    state_leaves = nest.flatten(initial_state)

    # Static shape check on every leaf tensor.
    for leaf, leaf_size in zip(state_leaves, leaf_sizes):
      self.assertEqual(leaf.get_shape(), [batch_size, leaf_size])

    with self.test_session() as sess:
      tf.global_variables_initializer().run()
      leaf_values = sess.run(state_leaves)
      for actual, leaf_size in zip(leaf_values, leaf_sizes):
        if not trainable:
          expected = np.zeros([batch_size, leaf_size])
        elif use_custom_initial_value:
          expected = np.full([batch_size, leaf_size], 2.0)
        else:
          # No fixed value is known here, so only require that every batch
          # row equals the first one.
          expected = np.tile(actual[0], (batch_size, 1))
        self.assertAllClose(actual, expected)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号