svm_utils.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:ml_defense 作者: arjunbhagoji 项目源码 文件源码
def model_trainer(model_dict, X_train, y_train, adv=None, rd=None, rev=None):
    """Train an SVM, save it to disk with joblib, and return it.

    Args:
        model_dict: Model configuration dict. This function reads the keys
            ``'svm_type'`` (``'linear'`` selects ``svm.LinearSVC``; any other
            value is passed as the ``kernel`` argument of ``svm.SVC``),
            ``'penconst'`` (used as the ``C`` constant) and ``'penalty'``
            (forwarded to ``LinearSVC``; not used for kernel SVMs). The whole
            dict is also passed to ``resolve_path_m`` and
            ``get_svm_model_name`` to build the output path.
        X_train: Training features, passed unchanged to ``clf.fit``.
        y_train: Training labels, passed unchanged to ``clf.fit``.
        adv: Not used by this function; kept so existing callers that pass
            it keep working.
        rd: Forwarded to ``get_svm_model_name`` (affects the saved filename).
        rev: Forwarded to ``get_svm_model_name`` (affects the saved filename).

    Returns:
        The fitted classifier (``LinearSVC`` or ``SVC``).

    Side effects:
        Prints progress/timing to stdout and writes the fitted model to
        ``resolve_path_m(model_dict) + get_svm_model_name(...) + '.pkl'``.
    """
    print('Training model...')
    start_time = time.time()
    abs_path_m = resolve_path_m(model_dict)
    svm_model = model_dict['svm_type']
    C = model_dict['penconst']
    penalty = model_dict['penalty']

    # Create model based on parameters
    if svm_model == 'linear':
        # LinearSVC supports the l1 penalty (with its default squared-hinge
        # loss) only in the primal formulation, so dual must be False then.
        dual = penalty != 'l1'
        clf = svm.LinearSVC(C=C, penalty=penalty, dual=dual)
    else:
        # Any other svm_type is used directly as the SVC kernel name
        # (presumably 'rbf', 'poly', 'sigmoid', ... -- TODO confirm callers).
        clf = svm.SVC(C=C, kernel=svm_model)

    # Train model
    clf.fit(X_train, y_train)
    print('Finish training in {:d}s'.format(int(time.time() - start_time)))

    # Save model; abs_path_m is concatenated as-is, so it is expected to
    # already end with a path separator.
    joblib.dump(clf, abs_path_m +
                get_svm_model_name(model_dict, rd, rev) + '.pkl')
    return clf
#------------------------------------------------------------------------------#
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号