common_defs.py 文件源码

python
阅读 48 收藏 0 点赞 0 评论 0

项目:hyperband 作者: zygmuntz 项目源码 文件源码
def _predict_positive_proba( clf, x ):
    """Return the positive-class probabilities predicted by ``clf`` for ``x``.

    sklearn classifiers return a 2-D array from ``predict_proba`` whose column 1
    is taken as the positive class. If the output cannot be sliced that way
    (e.g. a 1-D numpy array, where ``[:,1]`` raises IndexError), it is returned
    unchanged.
    """
    # Predict once; only the slice is retried on failure, not the prediction.
    proba = clf.predict_proba( x )
    try:
        return proba[:,1]   # sklearn convention
    except IndexError:
        return proba


def _evaluate_split( clf, x, y ):
    """Return ``(log_loss, auc, accuracy)`` of the fitted ``clf`` on ``x`` / ``y``."""
    p = _predict_positive_proba( clf, x )
    ll = log_loss( y, p )
    auc = AUC( y, p )
    # np.round turns probabilities into hard 0/1 predictions for accuracy.
    acc = accuracy( y, np.round( p ))
    return ll, auc, acc


def train_and_eval_sklearn_classifier( clf, data ):
    """Fit ``clf`` on the training split and report metrics on train and test.

    Args:
        clf: classifier exposing ``fit`` and ``predict_proba`` (sklearn-style).
        data: mapping with keys 'x_train', 'y_train', 'x_test', 'y_test'.

    Prints log loss, AUC and accuracy for both splits. ``log_loss``, ``AUC``
    and ``accuracy`` are module-level names (presumably sklearn.metrics
    functions; confirm against the file's imports).

    Returns:
        dict with 'loss' and 'log_loss' (both the test log loss) and 'auc'
        (the test AUC).
    """
    # Read all splits up front so a missing key fails before the (costly) fit.
    x_train = data['x_train']
    y_train = data['y_train']
    x_test = data['x_test']
    y_test = data['y_test']

    clf.fit( x_train, y_train )

    # Log loss is not a fraction, so it is printed as a plain number rather
    # than a percentage; AUC and accuracy are in [0, 1] and read well as %.
    ll, auc, acc = _evaluate_split( clf, x_train, y_train )
    print( "\n# training | log loss: {:.4f}, AUC: {:.2%}, accuracy: {:.2%}".format( ll, auc, acc ))

    ll, auc, acc = _evaluate_split( clf, x_test, y_test )
    print( "# testing  | log loss: {:.4f}, AUC: {:.2%}, accuracy: {:.2%}".format( ll, auc, acc ))

    return { 'loss': ll, 'log_loss': ll, 'auc': auc }

###

# "clf", even though it's a regressor
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号