def next_inputs(self, time, sample_ids=None, prev_finished=None):
    """Compute the ``finished`` flags and the next decoder inputs for ``time``.

    Termination rule:

    * If ``sample_ids`` is None or ``self.teacher_rate > 0``, a position is
      finished once ``time + 1 >= self.sequence_length``.
    * Otherwise (free running), a position is finished once
      ``time + 1 >= self.max_step`` or its sampled id equals ``self.eos_id``.

    Next-input rule:

    * If ``self.teacher_rate == 1`` or ``sample_ids`` is None, the ids come
      from ``self._input_tas.read(time)`` (teacher forcing).
    * Otherwise, if ``self.teacher_rate > 0``, each element independently
      takes the ``_input_tas`` id with probability ``teacher_rate`` (capped
      at 1) and the sampled id otherwise (scheduled sampling).
    * Otherwise, the sampled ids are fed back unchanged.

    Args:
        time: Current decoding step.
        sample_ids: Ids sampled at this step, or None. Presumably a
            ``(batch,)`` integer tensor; TODO confirm against the sampler.
        prev_finished: Accepted for interface compatibility. It is not used
            by this method.

    Returns:
        A ``(finished, next_inputs)`` tuple, where ``next_inputs`` is
        ``self.lookup(next_input_ids)``. ``lookup`` is presumably an
        embedding lookup; confirm.
    """
    if sample_ids is None or self.teacher_rate > 0.:
        # Ground-truth inputs drive decoding, so stop at the reference lengths.
        finished = tf.greater_equal(time + 1, self.sequence_length)
    else:
        # Free running: stop at the step cap or once EOS is emitted.
        finished = tf.logical_or(
            tf.greater_equal(time + 1, self.max_step),
            tf.equal(self.eos_id, sample_ids))

    if self.teacher_rate == 1. or sample_ids is None:
        # Pure teacher forcing: feed the stored input ids for this step.
        return finished, self.lookup(self._input_tas.read(time))

    if self.teacher_rate > 0.:
        # Scheduled sampling: U ~ Uniform[0, 1), so U <= teacher_rate holds
        # with probability teacher_rate for each element independently.
        use_teacher = tf.less_equal(
            tf.random_uniform(tf.shape(sample_ids), minval=0., maxval=1.),
            self.teacher_rate)
        # Select element-wise. Unlike the old int32-mask arithmetic, this
        # does not need the deprecated tf.to_int32 cast, and it works for
        # any id dtype as long as both id tensors share it.
        next_input_ids = tf.where(
            use_teacher, self._input_tas.read(time), sample_ids)
    else:
        # No teacher signal: feed back the model's own samples.
        next_input_ids = sample_ids
    return finished, self.lookup(next_input_ids)
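# Web-page residue captured with this snippet, not part of the program:
# "评论列表" ("comment list") and "文章目录" ("table of contents").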