def sample(self, logits, time):
    """Choose the ids for step `time`: stored inputs or ids sampled from logits.

    A switch-over step is computed as
    `self.sequence_length - floor(max(global_step - burn_in_step, 0) / increment_step)`.
    Where `time` is below that step the stored input ids are returned;
    where `time` has reached it, ids drawn from `logits` are returned instead.

    Args:
        logits: scores handed to `tf.multinomial`, which draws one sample per
            row (presumably shaped [batch_size, num_classes] -- confirm
            against the caller).
        time: current step; compared with `self.max_sequence_length` and used
            as the read index into `self._input_tas`.

    Returns:
        An int32 tensor of ids blended from the stored and the sampled ids.
    """
    # Number of whole `increment_step` intervals elapsed since burn-in ended;
    # clamped so it stays at zero while the global step is still in burn-in.
    steps_past_burn_in = tf.maximum(
        self.global_step_tensor - self.burn_in_step, 0)
    sampled_span = tf.floordiv(steps_past_burn_in, self.increment_step)
    # NOTE(review): if sequence_length is per-example, the switch-over step
    # (and the mask below) is per-example too -- confirm against the caller.
    switch_step = self.sequence_length - sampled_span

    def _eos_batch():
        # One copy of eos_id per batch entry.
        return tf.tile([self.eos_id], [self.batch_size])

    def _stored_inputs():
        return self._input_tas.read(time)

    # Reading is skipped once `time` reaches max_sequence_length; the EOS
    # batch is used in its place.
    past_end = tf.greater_equal(time, self.max_sequence_length)
    stored_ids = tf.cond(past_end, _eos_batch, _stored_inputs)

    # One draw per row, then drop the trailing size-1 sample axis.
    drawn = tf.multinomial(logits, 1)
    sampled_ids = tf.squeeze(drawn, axis=[-1])

    # 1 where the sampled id is used, 0 where the stored input id is kept.
    use_sampled = tf.to_int32(time >= switch_step)
    keep_stored = 1 - use_sampled
    stored_part = keep_stored * tf.to_int32(stored_ids)
    sampled_part = use_sampled * tf.to_int32(sampled_ids)
    return stored_part + sampled_part
# Comment list
# Article table of contents