def beam_setup(self, time):
    """Build the initial loop values for beam-search decoding.

    The returned 5-tuple has the layout of a ``tf.nn.raw_rnn`` ``loop_fn``
    initialization result. Presumably this method serves as that loop
    function's first-step branch; confirm against the caller.

    Args:
        time: Current loop step. Not read here. It is kept so the
            signature matches the loop-function calling convention.

    Returns:
        A tuple ``(elements_finished, next_input, next_cell_state,
        emit_output, next_loop_state)`` where:

        * ``elements_finished``: bool tensor of length ``self.batch_size``,
          all False.
        * ``next_input``: ``self.initial_input``, passed through unchanged.
        * ``next_cell_state``: ``self.initial_state``, passed through
          unchanged.
        * ``emit_output``: zeros shaped by ``self.cell.output_size``, with
          no batch dimension.
        * ``next_loop_state``: ``(cand_symbols, cand_logprobs,
          beam_symbols, beam_logprobs)``, described inline below.
    """
    rows = self.batch_size_times_beam_size
    static_rows = self.inferred_batch_size_times_beam_size

    next_cell_state = self.initial_state
    next_input = self.initial_input

    # --- Beam-search tracking state ---------------------------------------
    # Symbol histories start empty: int32 tensors of shape (rows, 0). The
    # fill value (stop_token) never appears because there are zero columns.
    cand_symbols = tf.fill(
        [rows, 0], tf.constant(self.stop_token, dtype=tf.int32))
    # Every candidate score starts at -inf.
    cand_logprobs = tf.ones((rows,), dtype=tf.float32) * -float('inf')

    # Row i belongs to beam slot (i % beam_size); True marks slot 0.
    first_in_beam_mask = tf.equal(tf.range(rows) % self.beam_size, 0)

    beam_symbols = tf.fill(
        [rows, 0], tf.constant(self.stop_token, dtype=tf.int32))
    # Only slot 0 of each beam starts with a real score (0.0). The other
    # slots get INVALID_SCORE, presumably so the identical initial
    # hypotheses are not all selected on the first expansion step.
    # NOTE: tf.select was removed in TensorFlow 1.0 in favour of the
    # three-argument tf.where; switch when this code moves off pre-1.0 TF.
    beam_logprobs = tf.select(
        first_in_beam_mask,
        tf.fill([rows], 0.0),
        tf.fill([rows], self.INVALID_SCORE)
    )

    # Relax the static shapes so they stay valid loop invariants. The symbol
    # histories grow along axis 1 on every step, so the inferred size 0 of
    # that axis must be cleared to None.
    # TODO: `_shape` is a private Tensor attribute (newer TensorFlow releases
    # may not allow assigning it). Find a public way to clear the inferred
    # shape so this hack is unnecessary.
    cand_symbols._shape = tf.TensorShape((static_rows, None))
    cand_logprobs._shape = tf.TensorShape((static_rows,))
    beam_symbols._shape = tf.TensorShape((static_rows, None))
    beam_logprobs._shape = tf.TensorShape((static_rows,))

    next_loop_state = (
        cand_symbols,
        cand_logprobs,
        beam_symbols,
        beam_logprobs,
    )

    # Zeros shaped by the cell's output size (no batch dimension).
    emit_output = tf.zeros(self.cell.output_size)
    # NOTE(review): this is sized by self.batch_size, whereas every per-row
    # tensor above uses self.batch_size_times_beam_size. Confirm which batch
    # dimension the decoding loop expects here.
    elements_finished = tf.zeros([self.batch_size], dtype=tf.bool)
    return elements_finished, next_input, next_cell_state, emit_output, next_loop_state
评论列表
文章目录