pooling_encoder_test.py source code

python
Project: seq2seq  Author: google
def _test_with_params(self, params):
    """Tests the encoder with a given parameter configuration."""
    # Random input batch of shape [batch_size, sequence_length, input_depth].
    inputs = tf.random_normal(
        [self.batch_size, self.sequence_length, self.input_depth])
    # Every example uses the full sequence length, so nothing is padded.
    example_length = tf.ones(
        self.batch_size, dtype=tf.int32) * self.sequence_length

    # Build the encoder graph for the given parameters and mode.
    encode_fn = PoolingEncoder(params, self.mode)
    encoder_output = encode_fn(inputs, example_length)

    # Initialize variables and evaluate the encoder output.
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        encoder_output_ = sess.run(encoder_output)

    # Outputs and attention values must keep the input shape.
    np.testing.assert_array_equal(
        encoder_output_.outputs.shape,
        [self.batch_size, self.sequence_length, self.input_depth])
    np.testing.assert_array_equal(
        encoder_output_.attention_values.shape,
        [self.batch_size, self.sequence_length, self.input_depth])
    # The final state holds one vector of size input_depth per example.
    np.testing.assert_array_equal(encoder_output_.final_state.shape,
                                  [self.batch_size, self.input_depth])
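
The shape contract checked by the test above can be illustrated without TensorFlow. The following is only a rough NumPy sketch, not the seq2seq library's implementation: `mean_pool_same` is a hypothetical helper, and the pool size of 3, stride of 1, "same"-style padding, and mean-over-time final state are all assumptions chosen so that the shapes match what the test asserts.

```python
import numpy as np


def mean_pool_same(inputs, pool_size=3):
    """Hypothetical stand-in: stride-1 average pooling over the time axis.

    Windows are clipped at the sequence boundaries, so the output keeps
    the input shape [batch, time, depth].
    """
    _, length, _ = inputs.shape
    left = (pool_size - 1) // 2
    right = pool_size - 1 - left
    outputs = np.empty_like(inputs)
    for t in range(length):
        lo = max(0, t - left)
        hi = min(length, t + right + 1)
        # Average the time steps that fall inside the window around t.
        outputs[:, t, :] = inputs[:, lo:hi, :].mean(axis=1)
    return outputs


# Illustrative sizes only; the real test reads them from self.*.
batch_size, sequence_length, input_depth = 4, 16, 10
inputs = np.random.randn(batch_size, sequence_length, input_depth)

outputs = mean_pool_same(inputs)
# Assumption: the final state is a summary over time, here a plain mean.
final_state = outputs.mean(axis=1)

np.testing.assert_array_equal(
    outputs.shape, [batch_size, sequence_length, input_depth])
np.testing.assert_array_equal(
    final_state.shape, [batch_size, input_depth])
```

With these assumptions, pooling leaves the time dimension unchanged and the final state drops it, which is exactly the pair of shapes the test expects from `outputs` and `final_state`.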