test_lr.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_fixed_lr(iter_buf, max_iter, base_lr):
    # set up
    name = 'fixed'
    params = {'name': name,
              'max_iter': max_iter,
              'base_lr': base_lr}

    # execute
    naive_lr = np.full(max_iter, base_lr)
    lr_op = lr_policies[name]['obj'](params)(iter_buf)
    with ExecutorFactory() as ex:
        compute_lr = ex.executor(lr_op, iter_buf)
        ng_lr = [compute_lr(i).item(0) for i in range(max_iter)]

        # compare
        ng.testing.assert_allclose(ng_lr, naive_lr, atol=1e-4, rtol=1e-3)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号