def testConstructRNN(self):
    """Check the shapes of the tensors returned by `construct_rnn`.

    Builds the sequence input from the fixture's column tensors and feature
    columns, then calls `dynamic_rnn_estimator.construct_rnn` with no initial
    state. It evaluates the activation and final-state tensors in a session
    and asserts that their shapes are `[3, 2, NUM_LABEL_COLUMNS]` and
    `[3, NUM_RNN_CELL_UNITS]`.
    """
    seq_input = dynamic_rnn_estimator.build_sequence_input(
        self.columns_to_tensors,
        self.sequence_feature_columns,
        self.context_feature_columns)

    # The initial state is passed as None; what construct_rnn substitutes
    # for it is not visible here. NOTE(review): presumably the cell's zero
    # state -- confirm.
    activations_op, final_state_op = dynamic_rnn_estimator.construct_rnn(
        None,
        seq_input,
        self.rnn_cell,
        self.mock_target_column.num_label_columns)

    with tf.Session() as session:
        # Variables and lookup tables must be initialized before the
        # graph outputs can be evaluated.
        session.run(tf.global_variables_initializer())
        # NOTE(review): tf.initialize_all_tables is deprecated in newer
        # TensorFlow in favor of tf.tables_initializer; switch once the
        # targeted TF version provides it.
        session.run(tf.initialize_all_tables())
        activations_val, final_state_val = session.run(
            [activations_op, final_state_op])

    # NOTE(review): 3 and 2 are presumably the fixture's batch size and
    # sequence length -- confirm against setUp.
    self.assertAllEqual(np.array([3, 2, self.NUM_LABEL_COLUMNS]),
                        activations_val.shape)
    self.assertAllEqual(np.array([3, self.NUM_RNN_CELL_UNITS]),
                        final_state_val.shape)
评论列表
文章目录