def next_inputs(self, sample_ids, name=None):
    """Compute the decoder inputs for the next step from the sampled ids.

    Args:
        sample_ids: tensor of token ids sampled at the current step; passed
            directly to ``tf.nn.embedding_lookup`` on ``self.target_embedding``.
        name: accepted for interface compatibility; currently unused.

    Returns:
        A ``(all_finished, next_inputs)`` tuple, where ``all_finished`` is a
        scalar boolean tensor that is true when every element of
        ``sample_ids`` equals ``self.config.eos_token``, and ``next_inputs``
        is the embedding of ``sample_ids`` looked up in
        ``self.target_embedding``.
    """
    finished = math_ops.equal(sample_ids, self.config.eos_token)
    all_finished = math_ops.reduce_all(finished)
    # No tf.cond is needed here. When every id is EOS, looking up sample_ids
    # already yields the EOS embedding, and with exactly the shape of
    # sample_ids. The former "finished" branch tiled EOS to [beam_width],
    # which produced a mismatched shape whenever sample_ids was not exactly
    # [beam_width] (e.g. batch size > 1).
    next_inputs = tf.nn.embedding_lookup(self.target_embedding, sample_ids)
    return all_finished, next_inputs
评论列表
文章目录