def lr_schedule_op(self):
    """Build an op that assigns a step-dependent value to ``self.learning_rate``.

    The learning rate is piecewise constant in ``self.global_step``:

    * step <  500000                 -> ``self.lr_start``
    * 500000  <= step < 1000000      -> 0.0005
    * 1000000 <= step < 2000000      -> 0.0003
    * step >= 2000000                -> 0.0001

    Returns:
        The result of ``self.learning_rate.assign(new_lr)``; running it in a
        session stores the scheduled value for the current step
        (presumably ``self.learning_rate`` is a ``tf.Variable`` — confirm).
    """
    # tf.less needs both operands to share a dtype; casting to int64 makes the
    # comparisons valid whether the global step is int32 or int64.
    step = tf.cast(self.global_step, tf.int64)

    # (exclusive upper bound on the step, learning rate used below that bound)
    stages = [
        (int(5e5), self.lr_start),
        (int(1e6), tf.constant(0.0005)),
        (int(2e6), tf.constant(0.0003)),
    ]
    final_lr = tf.constant(0.0001)

    # `lr=lr` binds each stage's value now; a bare closure would late-bind and
    # every branch would return the last stage's rate.
    pred_fn_pairs = [
        (tf.less(step, tf.constant(gate, dtype=tf.int64)), lambda lr=lr: lr)
        for gate, lr in stages
    ]

    # exclusive=False: predicates are checked in order and the first true one
    # wins, which is what makes the overlapping `step < gate` tests piecewise.
    new_lr = tf.case(pred_fn_pairs, default=lambda: final_lr, exclusive=False)
    return self.learning_rate.assign(new_lr)
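# NOTE(review): stray web-page navigation text ("comment list", "table of contents") stood here; it is not code.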