feedback.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:sequencing 作者: SwordYork 项目源码 文件源码
def sample(self, logits, time):
        rl_time_steps = tf.floordiv(tf.maximum(self.global_step_tensor -
                                               self.burn_in_step, 0),
                                    self.increment_step)
        start_rl_step = self.sequence_length - rl_time_steps

        next_input_ids = tf.cond(
            tf.greater_equal(time, self.max_sequence_length),
            lambda: tf.tile([self.eos_id], [self.batch_size]),
            lambda: self._input_tas.read(time))

        next_predicted_ids = tf.squeeze(tf.multinomial(logits, 1), axis=[-1])
        mask = tf.to_int32(time >= start_rl_step)

        return (1 - mask) * tf.to_int32(next_input_ids) + mask * tf.to_int32(
            next_predicted_ids)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号