basic_rnn_test.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目: sonnet 作者: deepmind 项目源码 文件源码
def testInitializers(self):
  """Checks VanillaRNN initializer validation and nested initializers.

  Three scenarios are exercised:
    * an unrecognised top-level key in ``initializers`` raises KeyError;
    * a non-callable value (a tensor) given for ``in_to_hidden``/``w``
      raises TypeError;
    * nested initializers for ``in_to_hidden``/``w`` and
      ``hidden_to_hidden``/``b`` are honoured, so after variable
      initialization both variables hold all-ones values.

  Relies on the test fixture defining ``self.batch_size``, ``self.in_size``
  and ``self.hidden_size``.
  """
  batch, hidden = self.batch_size, self.hidden_size
  input_ph = tf.placeholder(tf.float32, shape=[batch, self.in_size])
  state_ph = tf.placeholder(tf.float32, shape=[batch, hidden])

  # NOTE(review): assertRaisesRegexp is a deprecated alias of
  # assertRaisesRegex (removed from unittest in Python 3.12); kept as-is
  # pending confirmation of the Python versions this suite must support.

  # Unknown key at the top level of `initializers` must be rejected.
  with self.assertRaisesRegexp(KeyError, "Invalid initializer keys.*"):
    snt.VanillaRNN(name="rnn", hidden_size=hidden,
                   initializers={"invalid": None})

  # A tensor is not an initializer; only callables are accepted.
  with self.assertRaisesRegexp(
      TypeError, "Initializer for 'w' is not a callable function"):
    snt.VanillaRNN(
        name="rnn", hidden_size=hidden,
        initializers={"in_to_hidden": {"w": tf.zeros([10, 10])}})

  # Per-submodule (nested) initializers for one weight and one bias.
  nested = {
      "in_to_hidden": {"w": tf.ones_initializer()},
      "hidden_to_hidden": {"b": tf.ones_initializer()},
  }
  rnn = snt.VanillaRNN(name="rnn", hidden_size=hidden, initializers=nested)

  # Connect the module (presumably what creates its variables in Sonnet)
  # before building the global initializer op.
  rnn(input_ph, state_ph)
  init_op = tf.global_variables_initializer()

  with self.test_session() as session:
    session.run(init_op)
    weights, bias = session.run([rnn.in_to_hidden_linear.w,
                                 rnn.hidden_to_hidden_linear.b])
    self.assertAllClose(weights, np.ones([self.in_size, hidden]))
    self.assertAllClose(bias, np.ones([hidden]))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号