cross_validate_SVM.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:hco-experiments 作者: zooniverse 项目源码 文件源码
def main():

    parser = optparse.OptionParser("[!] usage: python cross_validate_SVM.py -F <data file>")

    parser.add_option("-F", dest="dataFile", type="string", \
                      help="specify data file to analyse")

    (options, args) = parser.parse_args()
    dataFile = options.dataFile

    if dataFile == None:
        print parser.usage
        exit(0)

    data = sio.loadmat(dataFile)

    X = data["X"]
    m,n = np.shape(X)
    y = np.squeeze(data["y"])

    kernel_grid = ["rbf"]
    C_grid = [5]
    gamma_grid = [1]

    kf = KFold(m, n_folds=5)
    fold = 1
    for kernel in kernel_grid:
        for C in C_grid:
            for gamma in gamma_grid:
                fold=1
                FoMs = []
                for train, test in kf:
                    print "[*]", fold, kernel, C, gamma
                    file = "cv/SVM_kernel"+str(kernel)+"_C"+str(C)+\
                           "_gamma"+str(gamma)+"_"+dataFile.split("/")[-1].split(".")[0]+\
                           "_fold"+str(fold)+".pkl"
                    try:
                        svm = pickle.load(open(file,"rb"))
                    except IOError:
                        train_x, train_y = X[train], y[train]
                        svm = train_SVM(train_x, train_y, kernel, C, gamma)
                        outputFile = open(file, "wb")
                        pickle.dump(svm, outputFile)
                    FoM, threshold = measure_FoM(X[test], y[test], svm, False)
                    fold+=1
                    FoMs.append(FoM)
                print "[+] mean FoM: %.3lf" % (np.mean(np.array(FoMs)))
                print
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号