functions.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:ECNN 作者: alazareva 项目源码 文件源码
def learning_rate(lr=LEARNING_RATE):
    decrease_rate = 0.75
    lr = lr
    window = []
    window_size = 5
    def f(loss = float('inf')):
        nonlocal window
        nonlocal lr
        nonlocal window_size
        window.append(loss)
        if len(window) == window_size:
            diffs = np.ediff1d(window)
            if np.all(abs(diffs) > np.array(window[:-1])*0.05) and np.mean(diffs > 0) >= 0.5: # if large loss
                # fluctuations
                print("fluctuating", window)
                lr *= decrease_rate
                window = []
            elif np.all(abs(diffs) < np.array(window[:-1])*0.01) and np.all(diffs < 0): # if decreased by
                # small amount
                print("too slow", window)
                lr *= 1/decrease_rate
                window = []
            else:
                window.pop(0)
        return lr
    return f
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号