builders.py source code

python

Project: dynamic-training-bench · Author: galeone
import tensorflow as tf


def build_optimizer(args, steps, global_step):
    """Build the specified optimizer, log the learning rate and enable
    learning rate decay if specified.
    Args:
        args: the optimization argument dict
        steps: dict of step counts; steps["decay"] is the number of steps
            between two learning rate decays
        global_step: integer tensor, the current training step
    Returns:
        optimizer: tf.Optimizer object initialized
    """
    # Extract the initial learning rate
    initial_lr = float(args["gd"]["args"]["learning_rate"])

    if args["lr_decay"]["enabled"]:
        # Decay the learning rate exponentially based on the number of steps.
        learning_rate = tf.train.exponential_decay(
            initial_lr,
            global_step,
            steps["decay"],
            args["lr_decay"]["factor"],
            staircase=True)
        # Update the learning rate parameter of the optimizer
        args["gd"]["args"]['learning_rate'] = learning_rate
        # Log the learning rate (tf_log is a project helper that registers
        # the summary op)
        tf_log(tf.summary.scalar('learning_rate', learning_rate))
    else:
        learning_rate = tf.constant(initial_lr)

    # Instantiate the optimizer, passing the (possibly decayed) learning rate
    optimizer = args["gd"]["optimizer"](**args["gd"]["args"])
    return optimizer
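
For context, a minimal usage sketch follows. The dict layout matches what build_optimizer reads; the optimizer class, hyperparameter values, and decay interval are assumed here for illustration, and the decay branch additionally requires the project's tf_log helper to be in scope.

import tensorflow as tf

# Assumed example configuration (the real values come from the
# project's configuration, not from this snippet)
args = {
    "gd": {
        "optimizer": tf.train.MomentumOptimizer,
        "args": {"learning_rate": 0.1, "momentum": 0.9},
    },
    "lr_decay": {"enabled": True, "factor": 0.1},
}
steps = {"decay": 25000}  # decay the learning rate every 25000 steps (assumed)

global_step = tf.train.get_or_create_global_step()
optimizer = build_optimizer(args, steps, global_step)

# A toy loss, just to show the returned optimizer in use
w = tf.Variable(3.0)
loss = tf.square(w)
train_op = optimizer.minimize(loss, global_step=global_step)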