model.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:attend_infer_repeat 作者: akosiorek 项目源码 文件源码
def _anneal_weight(init_val, final_val, anneal_type, global_step, anneal_steps, hold_for=0., steps_div=1.,
                       dtype=tf.float64):

        val, final, step, hold_for, anneal_steps, steps_div = (tf.cast(i, dtype) for i in
                                                               (init_val, final_val, global_step, hold_for, anneal_steps, steps_div))
        step = tf.maximum(step - hold_for, 0.)

        if anneal_type == 'exp':
            decay_rate = tf.pow(final / val, steps_div / anneal_steps)
            val = tf.train.exponential_decay(val, step, steps_div, decay_rate)

        elif anneal_type == 'linear':
            val = final + (val - final) * (1. - step / anneal_steps)
        else:
            raise NotImplementedError

        anneal_weight = tf.maximum(final, val)
        return anneal_weight
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号