import tensorflow as tf


def build_optimizer(args, steps, global_step):
"""Build the specified optimizer, log the learning rate and enalble
learning rate decay is specified.
Args:
args: the optimization argument dict
global_step: integer tensor, the current training step
Returns:
optimizer: tf.Optimizer object initialized
"""
    # Extract the initial learning rate
    initial_lr = float(args["gd"]["args"]["learning_rate"])

    if args["lr_decay"]["enabled"]:
        # Decay the learning rate exponentially based on the number of steps
        learning_rate = tf.train.exponential_decay(
            initial_lr,
            global_step,
            steps["decay"],
            args["lr_decay"]["factor"],
            staircase=True)
        # Update the learning rate parameter of the optimizer
        args["gd"]["args"]["learning_rate"] = learning_rate
        # Log the learning rate
        tf_log(tf.summary.scalar("learning_rate", learning_rate))
    else:
        learning_rate = tf.constant(initial_lr)

    # Instantiate the optimizer with the (possibly decayed) learning rate
    optimizer = args["gd"]["optimizer"](**args["gd"]["args"])
    return optimizer
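
A minimal usage sketch follows. The layout of `args` and `steps` is an assumption inferred from the dictionary accesses above, `tf_log` is stubbed out because the project's real summary helper is not shown in this snippet, and `MomentumOptimizer` is only one example of a `tf.train` optimizer class.

def tf_log(summary):
    # Minimal stand-in for the project's tf_log helper (assumed): the real
    # code base presumably registers the summary for later merging.
    return summary

# Hypothetical configuration, inferred from the accesses in build_optimizer
args = {
    "gd": {
        "optimizer": tf.train.MomentumOptimizer,  # any tf.train optimizer class
        "args": {"learning_rate": 0.1, "momentum": 0.9},
    },
    "lr_decay": {"enabled": True, "factor": 0.1},
}
steps = {"decay": 10000}  # decay the learning rate every 10000 steps

# Toy variable and loss so the optimizer has something to minimize
w = tf.get_variable("w", shape=[], initializer=tf.ones_initializer())
loss = tf.square(w - 3.0)

global_step = tf.train.get_or_create_global_step()
optimizer = build_optimizer(args, steps, global_step)
train_op = optimizer.minimize(loss, global_step=global_step)

With staircase=True, the effective rate is initial_lr * factor ** (global_step // steps["decay"]): 0.1 for the first 10000 steps, 0.01 for the next 10000, and so on.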