def _anneal_weight(init_val, final_val, anneal_type, global_step, anneal_steps, hold_for=0., steps_div=1.,
                   dtype=tf.float64):
    """Return a weight annealed from ``init_val`` towards ``final_val``.

    The weight is held at ``init_val`` for the first ``hold_for`` steps.
    Over the next ``anneal_steps`` steps it moves towards ``final_val``.
    After that it is clamped so it never moves past ``final_val``.
    Both decreasing (``init_val > final_val``) and increasing
    (``init_val < final_val``) schedules are supported.

    Args:
        init_val: Weight value at the start of the schedule.
        final_val: Weight value reached after ``hold_for + anneal_steps``
            steps.
        anneal_type: ``'exp'`` for a geometric interpolation between the two
            values, or ``'linear'`` for a linear one.
        global_step: Current training step.
        anneal_steps: Length of the annealing phase, in steps. The code
            divides by it, so it should be non-zero.
        hold_for: Number of initial steps during which the weight stays at
            ``init_val``.
        steps_div: For ``'exp'``, the ``decay_steps`` argument forwarded to
            ``tf.train.exponential_decay``. The decay rate is rescaled by the
            same amount, so the non-staircase schedule reaches ``final_val``
            at ``anneal_steps`` regardless of this value. Unused for
            ``'linear'``.
        dtype: TensorFlow dtype that every numeric input is cast to.

    Returns:
        A tensor of ``dtype`` holding the annealed weight.

    Raises:
        NotImplementedError: If ``anneal_type`` is neither ``'exp'`` nor
            ``'linear'``.
    """
    init, final, step, hold_for, anneal_steps, steps_div = (
        tf.cast(x, dtype)
        for x in (init_val, final_val, global_step, hold_for, anneal_steps, steps_div))
    # Shift the clock so the schedule only starts once the hold period is over.
    step = tf.maximum(step - hold_for, 0.)
    if anneal_type == 'exp':
        # Rate chosen so that init * rate ** (anneal_steps / steps_div) == final.
        # NOTE(review): tf.pow with a fractional exponent needs final / init > 0;
        # presumably both endpoints are positive -- confirm with callers.
        decay_rate = tf.pow(final / init, steps_div / anneal_steps)
        # NOTE(review): this is a TF1 API (tf.compat.v1 in TF2); under eager
        # execution it returns a callable rather than a tensor.
        val = tf.train.exponential_decay(init, step, steps_div, decay_rate)
    elif anneal_type == 'linear':
        val = final + (init - final) * (1. - step / anneal_steps)
    else:
        raise NotImplementedError('Unsupported anneal_type: %r' % (anneal_type,))
    # Both schedules keep moving past `final` once step > anneal_steps, so
    # clamp to the interval spanned by the two endpoints. A one-sided
    # tf.maximum(final, val) would pin increasing schedules to `final`
    # from the very first step.
    lower = tf.minimum(init, final)
    upper = tf.maximum(init, final)
    return tf.clip_by_value(val, lower, upper)
评论列表
文章目录