utils.py source code

python

Project: LINE  Author: VahidooX
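from sklearn import svm
from sklearn.metrics import accuracy_score


def svm_classify(X, label, split_ratios, C):
    """
    Trains a linear SVM and reports its accuracy on each split.

    X and label are sliced in order (no shuffling): the first
    split_ratios[0] fraction is the training set, the next
    split_ratios[1] fraction is the validation set, and the rest is
    the test set. C is the SVM penalty factor.
    Returns [train_acc, valid_acc, test_acc].
    """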
    train_size = int(len(X)*split_ratios[0])
    val_size = int(len(X)*split_ratios[1])

    train_data, valid_data, test_data = X[0:train_size], X[train_size:train_size + val_size], X[train_size + val_size:]
    train_label, valid_label, test_label = label[0:train_size], label[train_size:train_size + val_size], label[train_size + val_size:]

    print('training SVM...')
    clf = svm.SVC(C=C, kernel='linear')
    clf.fit(train_data, train_label.ravel())

    p = clf.predict(train_data)
    train_acc = accuracy_score(train_label, p)
    p = clf.predict(valid_data)
    valid_acc = accuracy_score(valid_label, p)
    p = clf.predict(test_data)
    test_acc = accuracy_score(test_label, p)

    return [train_acc, valid_acc, test_acc]
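
A minimal usage sketch, assuming NumPy arrays as input. The synthetic data, the seed, and the ratios [0.6, 0.2] are illustrative assumptions and do not come from the LINE project. A compact copy of svm_classify is included only so the snippet runs on its own. Because the function slices the arrays in order, the sketch shuffles the rows first; that matters whenever the input happens to be sorted by label.

```python
import numpy as np
from sklearn import svm
from sklearn.metrics import accuracy_score


def svm_classify(X, label, split_ratios, C):
    """Compact copy of the function above so this sketch is self-contained."""
    train_size = int(len(X) * split_ratios[0])
    val_size = int(len(X) * split_ratios[1])
    val_end = train_size + val_size

    clf = svm.SVC(C=C, kernel='linear')
    clf.fit(X[:train_size], label[:train_size].ravel())

    # (start, end) index pairs for the train, validation, and test slices
    parts = [(0, train_size), (train_size, val_end), (val_end, len(X))]
    return [accuracy_score(label[a:b].ravel(), clf.predict(X[a:b]))
            for a, b in parts]


# Synthetic, linearly separable toy data (hypothetical, for illustration only)
rng = np.random.default_rng(0)
n = 200
X = rng.normal(size=(n, 2))
label = (X[:, 0] + X[:, 1] > 0).astype(int)

# Shuffle rows and labels together before the sequential split
perm = rng.permutation(n)
X, label = X[perm], label[perm]

# 60% train, 20% validation, remaining 20% test
train_acc, valid_acc, test_acc = svm_classify(X, label, [0.6, 0.2], C=1.0)
print(train_acc, valid_acc, test_acc)
```

Only the first two entries of split_ratios are read; whatever remains after the training and validation slices becomes the test set.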