def test_encode(self):
    """Run the unidirectional RNN encoder once and check result shapes.

    Builds a random input batch in which every example is given the full
    ``self.sequence_length`` as its length, evaluates the encoder output in
    a test session, and verifies the shape of ``outputs`` as well as the
    type of ``final_state`` and the shapes of its ``h`` and ``c`` parts.
    """
    # NOTE(review): 32 is presumably the cell size configured through
    # self.params (populated outside this method) -- confirm.
    expected_units = 32

    batch_shape = [self.batch_size, self.sequence_length, self.input_depth]
    source = tf.random_normal(batch_shape)
    # Every example in the batch is marked as using the full length.
    full_lengths = tf.ones(
        self.batch_size, dtype=tf.int32) * self.sequence_length

    encoder = rnn_encoder.UnidirectionalRNNEncoder(self.params, self.mode)
    result_op = encoder(source, full_lengths)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        result = sess.run(result_op)

    np.testing.assert_array_equal(
        result.outputs.shape,
        [self.batch_size, self.sequence_length, expected_units])

    final_state = result.final_state
    self.assertIsInstance(final_state, tf.contrib.rnn.LSTMStateTuple)

    expected_state_shape = [self.batch_size, expected_units]
    np.testing.assert_array_equal(final_state.h.shape, expected_state_shape)
    np.testing.assert_array_equal(final_state.c.shape, expected_state_shape)
评论列表
文章目录