beam_search_test.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:seq2seq 作者: google 项目源码 文件源码
def test_step(self):
    """Checks one `beam_search.beam_search_step` call against hand-picked logits.

    The incoming beam state gives every beam the same log probability
    (log-softmax of a vector of ones), a length of 2, and no finished beams.
    The logits are nearly uniform except for four boosted entries. They are
    chosen so that the expected selected hypotheses, best first, are:
    token 3 from beam 1, token 3 from beam 0, and token 2 from beam 0.

    The test checks:
    - the selected token ids and their parent beams;
    - the incremented lengths and the finished flags;
    - the accumulated log probabilities.

    NOTE(review): the hard-coded expectations below imply
    `self.config.beam_width == 3` and `self.config.vocab_size >= 5` --
    confirm against the fixture that builds `self.config`.
    """
    beam_width = self.config.beam_width
    vocab_size = self.config.vocab_size

    # Every beam starts with the same log probability and length 2, and no
    # beam has finished yet.
    beam_state = beam_search.BeamSearchState(
        log_probs=tf.nn.log_softmax(tf.ones(beam_width)),
        lengths=tf.constant(2, shape=[beam_width], dtype=tf.int32),
        finished=tf.zeros([beam_width], dtype=tf.bool))

    # Nearly uniform logits with a few boosted entries, so the expected top
    # candidates are clearly separated from the rest.
    logits_ = np.full([beam_width, vocab_size], 0.0001)
    logits_[0, 2] = 1.9
    logits_[0, 3] = 2.1
    logits_[1, 3] = 3.1
    logits_[1, 4] = 0.9
    logits = tf.convert_to_tensor(logits_, dtype=tf.float32)
    # Per-token log probabilities, fetched alongside the step so the expected
    # scores can be rebuilt on the numpy side.
    log_probs = tf.nn.log_softmax(logits)

    outputs, next_beam_state = beam_search.beam_search_step(
        time_=2, logits=logits, beam_state=beam_state, config=self.config)

    with self.test_session() as sess:
        outputs_, next_state_, state_, log_probs_ = sess.run(
            [outputs, next_beam_state, beam_state, log_probs])

    # Expected winners, best first, as (parent beam, token id) pairs.
    expected_parent_ids = [1, 0, 0]
    expected_token_ids = [3, 3, 2]

    np.testing.assert_array_equal(outputs_.predicted_ids, expected_token_ids)
    np.testing.assert_array_equal(outputs_.beam_parent_ids, expected_parent_ids)
    # No beam was finished, so every hypothesis is expected to grow by one.
    np.testing.assert_array_equal(next_state_.lengths, [3, 3, 3])
    np.testing.assert_array_equal(next_state_.finished, [False, False, False])

    # Each new hypothesis' score is its parent's score plus log p(token | parent).
    # Integer-array indexing gathers one (parent, token) entry per hypothesis.
    expected_log_probs = (
        state_.log_probs[expected_parent_ids] +
        log_probs_[expected_parent_ids, expected_token_ids])
    # These are float32 scores: compare with a tolerance instead of bit-exactly,
    # so the test does not hinge on the exact order of float ops in the step.
    np.testing.assert_allclose(
        next_state_.log_probs, expected_log_probs, rtol=1e-6)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号