sgd.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目: hyperband 作者: zygmuntz 项目源码 文件源码
def try_params( n_iterations, params ):
    """Fit an SGD classifier with one hyperparameter configuration and score it.

    Parameters
    ----------
    n_iterations : int or float
        Iteration budget for this run; rounded to the nearest int and passed
        to ``SGD`` as ``n_iter``.
    params : dict
        Hyperparameter configuration. The optional ``'scaler'`` entry is a
        string that ``eval`` turns into a constructor call (``"<name>()"``),
        presumably the name of a scaler class imported in this module — a
        falsy or missing value means no scaling. Every other entry is
        forwarded to ``SGD`` as a keyword argument. ``params`` is never
        modified, so the caller may re-use it in a later round.

    Returns
    -------
    Whatever ``train_and_eval_sklearn_classifier(clf, local_data)`` returns,
    where ``local_data`` is the module-level ``data`` dict (keys ``x_train``,
    ``y_train``, ``x_test``, ``y_test``), optionally with scaled features.
    """
    n_iterations = int( round( n_iterations ))
    # One pre-formatted string: prints identically under the Python 2 print
    # statement and the Python 3 print function.
    print( "n_iterations: {}".format( n_iterations ))
    pprint( params )

    scaler_name = params.get( 'scaler' )
    if scaler_name:
        # SECURITY NOTE(review): eval() executes the scaler string as Python
        # code in this module's scope. Acceptable only while the value comes
        # from this project's own search space — never feed it external input.
        scaler = eval( "{}()".format( scaler_name ))
        # Fit the scaler on the training split only, then apply the same
        # transformation to the test split.
        x_train_ = scaler.fit_transform( data['x_train'].astype( float ))
        x_test_ = scaler.transform( data['x_test'].astype( float ))

        local_data = { 'x_train': x_train_, 'y_train': data['y_train'],
          'x_test': x_test_, 'y_test': data['y_test'] }
    else:
        local_data = data

    # Build a filtered copy rather than mutating params: at the next small
    # round the best params will be re-used, so the caller's dict must keep
    # its 'scaler' entry.
    params_ = { k: v for k, v in params.items() if k != 'scaler' }

    # NOTE(review): if SGD is scikit-learn's SGDClassifier, `n_iter` was
    # deprecated in 0.19 and removed in 0.21 in favour of `max_iter` —
    # confirm the targeted scikit-learn version.
    clf = SGD( n_iter = n_iterations, **params_ )

    return train_and_eval_sklearn_classifier( clf, local_data )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号