decoder_test.py source code

python

Project: seq2seq  Author: google
def test_with_dynamic_inputs(self):
    """Decodes greedily in INFER mode and checks the output shapes."""
    # Embedding matrix the helper uses to look up each predicted token.
    embeddings = tf.get_variable("W_embed", [self.vocab_size, self.input_depth])

    # end_token=-1 never matches a predicted id, so decoding is not cut short
    # by an end-of-sequence token.
    helper = decode_helper.GreedyEmbeddingHelper(
        embedding=embeddings, start_tokens=[0] * self.batch_size, end_token=-1)
    decoder_fn = self.create_decoder(
        helper=helper, mode=tf.contrib.learn.ModeKeys.INFER)
    initial_state = decoder_fn.cell.zero_state(
        self.batch_size, dtype=tf.float32)
    decoder_output, _ = decoder_fn(initial_state, helper)

    #pylint: disable=E1101
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        decoder_output_ = sess.run(decoder_output)

    # Outputs are time-major: [time, batch, ...].
    np.testing.assert_array_equal(
        decoder_output_.logits.shape,
        [self.max_decode_length, self.batch_size, self.vocab_size])
    np.testing.assert_array_equal(decoder_output_.predicted_ids.shape,
                                  [self.max_decode_length, self.batch_size])
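
The shape contract this test asserts can be reproduced without TensorFlow. The sketch below is a toy stand-in, not the seq2seq library's decoder: `greedy_decode` and `output_weights` are hypothetical names, and a single matrix multiply replaces the RNN cell. It only illustrates how feeding each step's argmax back through the embedding matrix yields time-major `logits` and `predicted_ids`.

```python
import numpy as np


def greedy_decode(embeddings, output_weights, start_tokens, max_decode_length):
    """Toy greedy decoder (hypothetical helper, not part of seq2seq).

    At every step the current token ids are embedded, projected to vocabulary
    logits, and the argmax becomes the next step's input.
    """
    batch_size = len(start_tokens)
    vocab_size = embeddings.shape[0]
    token_ids = np.asarray(start_tokens)
    # Time-major buffers: the step index is the leading dimension.
    logits = np.zeros((max_decode_length, batch_size, vocab_size))
    predicted_ids = np.zeros((max_decode_length, batch_size), dtype=np.int64)
    for step in range(max_decode_length):
        inputs = embeddings[token_ids]           # [batch_size, input_depth]
        logits[step] = inputs @ output_weights   # [batch_size, vocab_size]
        token_ids = logits[step].argmax(axis=-1)
        predicted_ids[step] = token_ids
    # No end-token check: like end_token=-1 in the test, decoding always runs
    # for the full max_decode_length steps.
    return logits, predicted_ids


# Illustrative sizes, chosen arbitrarily for this sketch.
vocab_size, input_depth, batch_size, max_decode_length = 7, 5, 4, 10
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(vocab_size, input_depth))
output_weights = rng.normal(size=(input_depth, vocab_size))

logits, predicted_ids = greedy_decode(
    embeddings, output_weights, [0] * batch_size, max_decode_length)

np.testing.assert_array_equal(
    logits.shape, [max_decode_length, batch_size, vocab_size])
np.testing.assert_array_equal(
    predicted_ids.shape, [max_decode_length, batch_size])
```

Because every batch entry starts from token 0 and the toy step is deterministic, all columns of `predicted_ids` are identical here. The real decoder carries RNN state between steps, which this sketch omits.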