learning_rate.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Buffe 作者: bentzinir 项目源码 文件源码
def create_learning_rate_func(solver_params):
    base = tt.fscalar('base')
    gamma = tt.fscalar('gamma')
    power = tt.fscalar('power')
    itrvl = tt.fscalar('itrvl')
    iter = tt.scalar('iter')

    if solver_params['lr_type']=='inv':
        lr_ = base * tt.pow(1 + gamma * iter, -power)

        lr = t.function(
            inputs=[iter, t.Param(base, default=solver_params['base']), t.Param(gamma, default=solver_params['gamma']), t.Param(power, default=solver_params['power'])],
            outputs=lr_)

    elif solver_params['lr_type']=='fixed':
        lr_ = base

        lr = t.function(
            inputs=[iter, t.Param(base, default=solver_params['base'])],
            outputs=lr_,
            on_unused_input='ignore')

    elif solver_params['lr_type']=='episodic':
        lr_ = base / (tt.floor(iter/itrvl) + 1)

        lr = t.function(
            inputs=[iter, t.Param(base, default=solver_params['base']), t.Param(itrvl, default=solver_params['interval'])],
            outputs=lr_,
            on_unused_input='ignore')
    return lr
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号