dynamic_rnn_estimator_test.py source code

python

Project: lsdc  Author: febert
def testConstructRNN(self):
    initial_state = None
    sequence_input = dynamic_rnn_estimator.build_sequence_input(
        self.columns_to_tensors,
        self.sequence_feature_columns,
        self.context_feature_columns)
    activations_t, final_state_t = dynamic_rnn_estimator.construct_rnn(
        initial_state,
        sequence_input,
        self.rnn_cell,
        self.mock_target_column.num_label_columns)

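    # Obtain values of activations and final state.
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      # Note: later TensorFlow 1.x releases deprecate this call in favor of
      # tf.tables_initializer().
      sess.run(tf.initialize_all_tables())
      activations, final_state = sess.run([activations_t, final_state_t])

    # Activations carry one projected output vector per sequence and time step,
    # so three dimensions are expected; the last is the number of label columns.
    expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS])
    self.assertAllEqual(expected_activations_shape, activations.shape)
    # The final state has one row per sequence and one column per RNN cell unit.
    expected_state_shape = np.array([3, self.NUM_RNN_CELL_UNITS])
    self.assertAllEqual(expected_state_shape, final_state.shape)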
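
The shape assertions in the test can be illustrated with a toy, framework-free sketch. This is not how `construct_rnn` is implemented; it is a plain tanh RNN written with the standard library only, assuming a batch of 3 sequences padded to length 2. The names `run_toy_rnn` and `shape`, and the sizes 4 (cell units) and 5 (label columns), are hypothetical, since the snippet does not show the values of `NUM_RNN_CELL_UNITS` or `NUM_LABEL_COLUMNS`.

```python
import math
import random


def run_toy_rnn(inputs, num_units, num_outputs, seed=0):
    """Run a plain tanh RNN over `inputs` and project each step's output.

    `inputs` is a nested list shaped [batch_size, padded_length, input_dim].
    Returns (activations, final_state) as nested lists.
    """
    rng = random.Random(seed)
    input_dim = len(inputs[0][0])

    def rand_matrix(rows, cols):
        return [[rng.uniform(-0.1, 0.1) for _ in range(cols)]
                for _ in range(rows)]

    w_in = rand_matrix(input_dim, num_units)     # input -> hidden
    w_rec = rand_matrix(num_units, num_units)    # hidden -> hidden
    w_out = rand_matrix(num_units, num_outputs)  # hidden -> projected output

    def matvec(vec, matrix):
        # Row vector times matrix: one entry per matrix column.
        return [sum(v * row[j] for v, row in zip(vec, matrix))
                for j in range(len(matrix[0]))]

    activations, final_state = [], []
    for sequence in inputs:
        state = [0.0] * num_units  # zero initial state for this sequence
        steps = []
        for x in sequence:
            pre = [a + b for a, b in zip(matvec(x, w_in), matvec(state, w_rec))]
            state = [math.tanh(p) for p in pre]
            steps.append(matvec(state, w_out))
        activations.append(steps)
        final_state.append(state)
    return activations, final_state


def shape(nested):
    """Return the shape of a rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return dims


# Hypothetical sizes: a batch of 3 sequences, each padded to length 2.
batch = [[[1.0, 0.0], [0.0, 1.0]] for _ in range(3)]
activations, final_state = run_toy_rnn(batch, num_units=4, num_outputs=5)
print(shape(activations))  # [3, 2, 5]
print(shape(final_state))  # [3, 4]
```

The activations keep the batch and time dimensions and end in the projection size, while the final state drops the time dimension and ends in the number of cell units, which is the same pattern the test asserts.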