train_novelty_detection.py source code

python

Project: keras-transfer-learning-for-oxford102  Author: Arsey

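import joblib
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, train_test_split
# `config` and `encode()` are project-local helpers that are not shown on this page.


def train_logistic():
    # Load the saved activations and split them into features, labels and class names.
    df = pd.read_csv(config.activations_path)
    df, y, classes = encode(df)

    # Hold out 20% of the samples for the final accuracy check.
    X_train, X_test, y_train, y_test = train_test_split(df.values, y, test_size=0.2, random_state=17)

    # Grid-search C and tol by log loss with 3-fold CV; refit=True retrains the best model on all of X_train.
    params = {'C': [10, 2, .9, .4, .1], 'tol': [0.0001, 0.001, 0.0005]}
    # multi_class='multinomial' is dropped: it is deprecated since scikit-learn 1.5, and lbfgs
    # already fits a multinomial model for multiclass targets.
    log_reg = LogisticRegression(solver='lbfgs', class_weight='balanced')
    clf = GridSearchCV(log_reg, params, scoring='neg_log_loss', refit=True, cv=3, n_jobs=-1)
    clf.fit(X_train, y_train)

    print("best params: " + str(clf.best_params_))
    print("Accuracy: ", accuracy_score(y_test, clf.predict(X_test)))

    # Store the class names on the model so they are saved together with it.
    setattr(clf, '__classes', classes)
    # Save the fitted model for later use.
    joblib.dump(clf, config.get_novelty_detection_model_path())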
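A self-contained sketch of the same grid search that runs without the project's activations CSV. It assumes synthetic data from `make_classification` as a stand-in for the activations, uses hypothetical class names in place of whatever `encode()` returns, writes the model to a temporary file instead of `config.get_novelty_detection_model_path()`, and omits `multi_class` as in the corrected function. It also shows reloading the saved model and reading back the `__classes` attribute.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, train_test_split

# Synthetic stand-in for the activations CSV: 300 samples, 20 features, 3 classes.
X, y = make_classification(n_samples=300, n_features=20, n_informative=10,
                           n_classes=3, random_state=17)
# Hypothetical class names; in the original they come from encode().
classes = np.array(["daisy", "rose", "tulip"])

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=17)

# Same search space and scoring as train_logistic(); max_iter is raised to avoid
# convergence warnings on this toy data.
params = {"C": [10, 2, 0.9, 0.4, 0.1], "tol": [1e-4, 1e-3, 5e-4]}
log_reg = LogisticRegression(solver="lbfgs", class_weight="balanced", max_iter=1000)
clf = GridSearchCV(log_reg, params, scoring="neg_log_loss", refit=True, cv=3)
clf.fit(X_train, y_train)

accuracy = accuracy_score(y_test, clf.predict(X_test))
print("best params:", clf.best_params_)
print("Accuracy:", accuracy)

# Attach the class names and persist the model (temporary path for this sketch).
setattr(clf, "__classes", classes)
path = os.path.join(tempfile.mkdtemp(), "novelty_model.joblib")
joblib.dump(clf, path)

# Reload the model and map an integer prediction back to a class name.
restored = joblib.load(path)
pred = restored.predict(X_test[:1])[0]
label = getattr(restored, "__classes")[pred]
print("first test sample ->", label)
```

Reading the names back with `getattr(..., "__classes")` works because `setattr` stores the attribute under the literal string `__classes`, and pickling through joblib keeps it in the object's `__dict__`. Writing `clf.__classes` inside a class body would instead be name-mangled to `_ClassName__classes`.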