def testComputation(self):
  """Check that ModelRNN's output and next state both echo the fed previous state.

  A random tensor is used as the RNN input. The previous state is a float32
  placeholder fed with random numpy data. The test asserts that the step output
  matches the fed state and that the next state matches the step output.

  NOTE(review): this relies on `self.model` (built outside this method, likely
  in setUp) behaving as an identity map on the state — confirm there.
  """
  rnn_core = snt.ModelRNN(self.model)
  random_inputs = tf.random_normal([self.batch_size, 5])
  state_shape = [self.batch_size, self.hidden_size]
  state_ph = tf.placeholder(tf.float32, shape=state_shape)

  output_op, next_state_op = rnn_core(random_inputs, state_ph)

  with self.test_session() as sess:
    # Same shape as the placeholder; numpy's float64 values are fed into it.
    fed_state = np.random.randn(*state_shape)
    sess.run(tf.global_variables_initializer())
    output_val, next_state_val = sess.run(
        [output_op, next_state_op], feed_dict={state_ph: fed_state})

  # The evaluated numpy values no longer need the session.
  self.assertAllClose(fed_state, output_val)
  self.assertAllClose(output_val, next_state_val)
# Scraped web-page residue, not code: "评论列表" (comment list), "文章目录" (table of contents).