test_hpo.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:brainiak 作者: brainiak 项目源码 文件源码
def test_simple_hpo():
    """Exercise ``fmin`` on a one-dimensional quadratic objective.

    Covers:

    * minimisation with a bounded ``scipy.stats`` uniform prior;
    * continuation: each call appends ``max_evals`` new entries to the
      caller-supplied ``trials`` list;
    * all sampled points and losses stay inside the prior's range;
    * rejection (``ValueError``) of a distribution given as a plain string;
    * acceptance of a ``scipy.stats`` distribution with no ``'lo'``/``'hi'``
      bounds.
    """

    def f(args):
        # Objective x**2: minimised at x = 0, and < 100 for |x| < 10.
        x = args['x']
        return x*x

    # uniform(loc=-10, scale=20) samples from [-10, 10]; 'lo'/'hi' state the
    # same bounds explicitly.
    s = {'x': {'dist': st.uniform(loc=-10., scale=20), 'lo': -10., 'hi': 10.}}
    trials = []

    # Test fmin and ability to continue adding to trials
    best = fmin(loss_fn=f, space=s, max_evals=40, trials=trials)
    best = fmin(loss_fn=f, space=s, max_evals=10, trials=trials)

    assert len(trials) == 50, "HPO continuation trials not working"

    # Resume a second time. No verbose argument is passed, so this checks
    # that continuation keeps appending after an earlier resume.
    best = fmin(loss_fn=f, space=s, max_evals=10, trials=trials)

    assert len(trials) == 60, "HPO continuation trials not working"

    # Every recorded loss and sample must respect the prior's range.
    yarray = np.array([tr['loss'] for tr in trials])
    np.testing.assert_array_less(yarray, 100.)

    xarray = np.array([tr['x'] for tr in trials])
    np.testing.assert_array_less(np.abs(xarray), 10.)

    assert best['loss'] < 100., "HPO out of range"
    assert np.abs(best['x']) < 10., "HPO out of range"

    # Test unknown distributions: a plain string is not an accepted
    # distribution object, so fmin is expected to raise ValueError.
    s2 = {'x': {'dist': 'normal', 'mu': 0., 'sigma': 1.}}
    trials2 = []
    with pytest.raises(ValueError) as excinfo:
        fmin(loss_fn=f, space=s2, max_evals=40, trials=trials2)
    assert "Unknown distribution type for variable" in str(excinfo.value)

    # A scipy.stats distribution without explicit 'lo'/'hi' bounds must be
    # accepted and run the full number of evaluations.
    s3 = {'x': {'dist': st.norm(loc=0., scale=1.)}}
    trials3 = []
    fmin(loss_fn=f, space=s3, max_evals=40, trials=trials3)
    assert len(trials3) == 40, "HPO with unbounded distribution not working"
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号