def testInitializers(self):
  """VanillaRNN rejects bad initializers and applies valid nested ones.

  Checks that:
    * an unrecognised top-level initializer key raises a KeyError;
    * a non-callable value (a Tensor) given as the nested 'w' initializer
      raises a TypeError;
    * a nested initializer dict sets `in_to_hidden_linear.w` and
      `hidden_to_hidden_linear.b` to all-ones once the module is connected
      and the variables are initialized.
  """
  batch_size = self.batch_size
  in_size = self.in_size
  hidden_size = self.hidden_size

  # Placeholders used only to connect the module so its variables get built.
  input_ph = tf.placeholder(tf.float32, shape=[batch_size, in_size])
  state_ph = tf.placeholder(tf.float32,
                            shape=[batch_size, hidden_size])

  # An unrecognised top-level key must be rejected at construction time.
  with self.assertRaisesRegexp(KeyError, "Invalid initializer keys.*"):
    snt.VanillaRNN(
        name="rnn",
        hidden_size=hidden_size,
        initializers={"invalid": None})

  # A Tensor is not callable, so it must be rejected as an initializer.
  not_callable_pattern = "Initializer for 'w' is not a callable function"
  with self.assertRaisesRegexp(TypeError, not_callable_pattern):
    snt.VanillaRNN(
        name="rnn",
        hidden_size=hidden_size,
        initializers={"in_to_hidden": {"w": tf.zeros([10, 10])}})

  # Nested initializers: one dict per sub-module, each overriding a single
  # variable with ones.
  nested_initializers = dict(
      in_to_hidden=dict(w=tf.ones_initializer()),
      hidden_to_hidden=dict(b=tf.ones_initializer()),
  )
  rnn = snt.VanillaRNN(
      name="rnn", hidden_size=hidden_size, initializers=nested_initializers)
  # Connect the module before building the initializer op so that its
  # variables are covered by it.
  rnn(input_ph, state_ph)
  init_op = tf.global_variables_initializer()

  with self.test_session() as sess:
    sess.run(init_op)
    fetched = sess.run({
        "w": rnn.in_to_hidden_linear.w,
        "b": rnn.hidden_to_hidden_linear.b,
    })
    self.assertAllClose(fetched["w"], np.ones((in_size, hidden_size)))
    self.assertAllClose(fetched["b"], np.ones((hidden_size,)))
评论列表
文章目录