def test_fixed_lr(iter_buf, max_iter, base_lr):
    """Check the 'fixed' learning-rate policy against a constant reference.

    The policy op is evaluated for every iteration index in
    ``range(max_iter)`` and compared, within tolerance, to an array holding
    ``base_lr`` at every position.
    """
    policy = 'fixed'
    config = dict(name=policy, max_iter=max_iter, base_lr=base_lr)

    # Reference schedule: base_lr repeated max_iter times.
    expected = np.full(max_iter, base_lr)

    # Look up the policy constructor, configure it, and apply it to iter_buf.
    make_policy = lr_policies[policy]['obj']
    schedule_op = make_policy(config)(iter_buf)

    with ExecutorFactory() as factory:
        evaluate = factory.executor(schedule_op, iter_buf)
        observed = []
        for step in range(max_iter):
            # .item(0) pulls the first element out as a Python scalar.
            observed.append(evaluate(step).item(0))

        ng.testing.assert_allclose(observed, expected, atol=1e-4, rtol=1e-3)
评论列表
文章目录