def adjust_max(start, stop, start_value, stop_value, name=None):
    """Piecewise schedule on the global step: hold, linear decay, hold.

    Returns `start_value` while global_step <= `start`, decays linearly
    (polynomial decay with power=1.0) from `start_value` to `stop_value`
    while `start` < global_step <= `stop`, and returns `stop_value`
    afterwards.

    Args:
        start: Step at which the decay begins (converted to int64 tensor).
        stop: Step at which the decay ends (converted to int64 tensor).
        start_value: Value before/at `start` (converted to float32 tensor).
        stop_value: Value at/after `stop` (converted to float32 tensor).
        name: Optional name for the op's name scope.

    Returns:
        A scalar float32 tensor with the scheduled value, or None when no
        global step tensor has been created in the current graph.
    """
    with ops.name_scope(name, "AdjustMax",
                        [start, stop, start_value, stop_value]) as name:
        global_step = tf.train.get_global_step()
        # Guard clause: without a global step there is nothing to schedule.
        if global_step is None:
            return None
        start = tf.convert_to_tensor(start, dtype=tf.int64)
        stop = tf.convert_to_tensor(stop, dtype=tf.int64)
        start_value = tf.convert_to_tensor(start_value, dtype=tf.float32)
        stop_value = tf.convert_to_tensor(stop_value, dtype=tf.float32)
        # Use a list of (pred, fn) pairs: the dict form of tf.case is
        # deprecated and gives nondeterministic branch ordering, since the
        # keys are tensors.
        pred_fn_pairs = [
            (global_step <= start, lambda: start_value),
            ((global_step > start) & (global_step <= stop),
             lambda: tf.train.polynomial_decay(
                 start_value, global_step - start, stop - start,
                 end_learning_rate=stop_value, power=1.0, cycle=False)),
        ]
        # The predicates partition the step range, so exclusive=True holds;
        # the default branch covers global_step > stop.
        return tf.case(pred_fn_pairs, default=lambda: stop_value,
                       exclusive=True)