def learning_rate(lr=LEARNING_RATE, *, decrease_rate=0.75, window_size=5,
                  fluctuation_tol=0.05, stall_tol=0.01):
    """Build a stateful learning-rate scheduler driven by recent loss values.

    The returned closure ``f(loss)`` appends ``loss`` to a sliding window.
    Once the window holds ``window_size`` values, each consecutive step is
    compared against the magnitude of the loss that preceded it:

    * every step changes the loss by more than ``fluctuation_tol`` (relative)
      and at least half of the steps are increases -> the loss is fluctuating,
      so the learning rate is multiplied by ``decrease_rate``;
    * every step is a decrease smaller than ``stall_tol`` (relative) ->
      progress is too slow, so the learning rate is divided by
      ``decrease_rate``.

    After either adjustment the window is emptied, so a fresh full window must
    be collected before the next change. Otherwise the oldest value is dropped
    and the window keeps sliding.

    Args:
        lr: Initial learning rate.
        decrease_rate: Factor applied when the loss fluctuates; its reciprocal
            is applied when progress stalls. Must be positive.
        window_size: Number of recent losses compared. Must be at least 2.
        fluctuation_tol: Relative step size above which a step counts as a
            large fluctuation.
        stall_tol: Relative step size below which a decrease counts as too
            small.

    Returns:
        A function ``f(loss=float('inf'))`` that records ``loss`` and returns
        the (possibly updated) learning rate.

    Raises:
        ValueError: If ``window_size < 2`` or ``decrease_rate <= 0``.
    """
    if window_size < 2:
        # With fewer than two values there are no steps to compare, and the
        # "too slow" test would be vacuously true on every call.
        raise ValueError(f"window_size must be at least 2, got {window_size}")
    if decrease_rate <= 0:
        raise ValueError(f"decrease_rate must be positive, got {decrease_rate}")

    window = []

    def f(loss=float('inf')):
        # NOTE(review): calling f() with no argument records an infinite loss,
        # which blocks any adjustment until it slides out of the window.
        nonlocal lr
        window.append(loss)
        if len(window) < window_size:
            return lr

        values = np.asarray(window)
        diffs = np.ediff1d(values)
        step_sizes = np.abs(diffs)
        # Scale by the magnitude of the previous loss so the relative
        # thresholds stay meaningful when losses are negative.
        scale = np.abs(values[:-1])

        fluctuating = (np.all(step_sizes > scale * fluctuation_tol)
                       and np.mean(diffs > 0) >= 0.5)
        stalled = (np.all(step_sizes < scale * stall_tol)
                   and np.all(diffs < 0))

        if fluctuating:
            print("fluctuating", window)
            lr *= decrease_rate
            window.clear()
        elif stalled:
            print("too slow", window)
            lr *= 1 / decrease_rate
            window.clear()
        else:
            # Slide the window: drop the oldest loss.
            window.pop(0)
        return lr

    return f
评论列表
文章目录