def test_step(self):
    """Run a single beam-search step and verify its outputs and next state.

    The initial state gives every beam the same log-probability
    (log-softmax of a vector of ones), a length of 2, and marks no beam as
    finished. The logits are almost flat (0.0001 everywhere) except for
    four boosted entries on beams 0 and 1, so the winning
    (parent beam, token) pairs are known in advance.

    The expectations below list three hypotheses, so this test presumably
    relies on ``self.config.beam_width == 3`` -- confirm against the
    fixture's ``setUp``.
    """
    beam_width = self.config.beam_width
    beam_state = beam_search.BeamSearchState(
        log_probs=tf.nn.log_softmax(tf.ones(beam_width)),
        lengths=tf.constant(2, shape=[beam_width], dtype=tf.int32),
        finished=tf.zeros([beam_width], dtype=tf.bool))

    # Near-uniform logits with a few clearly dominant entries.
    logits_ = np.full([beam_width, self.config.vocab_size], 0.0001)
    logits_[0, 2] = 1.9
    logits_[0, 3] = 2.1
    logits_[1, 3] = 3.1
    logits_[1, 4] = 0.9
    logits = tf.convert_to_tensor(logits_, dtype=tf.float32)
    log_probs = tf.nn.log_softmax(logits)

    outputs, next_beam_state = beam_search.beam_search_step(
        time_=2, logits=logits, beam_state=beam_state, config=self.config)

    with self.test_session() as sess:
        outputs_, next_state_, state_, log_probs_ = sess.run(
            [outputs, next_beam_state, beam_state, log_probs])

    # Expected winners, best first: token 3 from beam 1, then tokens 3 and 2
    # from beam 0.
    expected_parent_ids = [1, 0, 0]
    expected_token_ids = [3, 3, 2]

    # Integer / boolean results must match exactly.
    np.testing.assert_array_equal(outputs_.predicted_ids, expected_token_ids)
    np.testing.assert_array_equal(
        outputs_.beam_parent_ids, expected_parent_ids)
    np.testing.assert_array_equal(next_state_.lengths, [3, 3, 3])
    np.testing.assert_array_equal(
        next_state_.finished, [False, False, False])

    # New cumulative score = parent's score + log-prob of the chosen token.
    # Fancy indexing returns fresh arrays, so `state_` is left untouched.
    expected_log_probs = (
        state_.log_probs[expected_parent_ids]
        + log_probs_[expected_parent_ids, expected_token_ids])

    # Floating-point values are compared with a tolerance: the graph and
    # NumPy may legitimately differ in the last bits depending on the
    # device/kernel, and bit-exact equality would make the test flaky.
    np.testing.assert_allclose(
        next_state_.log_probs, expected_log_probs, rtol=1e-6, atol=0)
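# Scraped page navigation text (not part of the code): "Comment list"
# Scraped page navigation text (not part of the code): "Table of contents"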