methods.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:South-African-Heart-Disease-data-analysis-using-python 作者: khushi4tiwari 项目源码 文件源码
def _classification_error_rate(y_true, y_pred):
    """Return the fraction of samples whose prediction disagrees with its label.

    Labels and predictions are compared against a 0.5 threshold: a sample is
    counted as correct only when both values are strictly below 0.5 or both
    are strictly above 0.5 (a value of exactly 0.5 therefore counts as wrong).
    Both inputs are flattened, so a column vector of labels is accepted.

    :param y_true: true labels, assumed numeric (e.g. 0/1) — confirm with callers.
    :param y_pred: predicted labels, same length as ``y_true``.
    :returns: error rate as a float in [0, 1], or NaN when there are no samples.
    """
    import numpy as np

    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.size == 0:
        return float("nan")
    both_low = (y_true < 0.5) & (y_pred < 0.5)
    both_high = (y_true > 0.5) & (y_pred > 0.5)
    wrong = y_true.size - int(np.count_nonzero(both_low | both_high))
    return wrong / float(y_true.size)


def decisionTree(X,y,attributeNames,classNames,fileName,s="",X_train=None,y_train=None, X_test=None, y_test=None):
    """Fit a Gini decision-tree classifier, export it as Graphviz, and report its error rate.

    If any of ``X_train``, ``y_train``, ``X_test`` or ``y_test`` is missing,
    the tree is trained and evaluated on the full ``X``/``y``. The result is
    then the resubstitution (training) error, not a held-out error.

    :param X: full feature matrix (2-D, one row per sample), used when no split is given.
    :param y: full label vector matching ``X``.
    :param attributeNames: feature names passed to ``export_graphviz``.
    :param classNames: accepted for interface compatibility; not used.
    :param fileName: path the Graphviz ``.dot`` description of the tree is written to.
    :param s: free-text label printed before the result.
    :param X_train: optional training features.
    :param y_train: optional training labels.
    :param X_test: optional test features.
    :param y_test: optional test labels.
    :returns: misclassification rate on the test data (NaN if the test set is empty).
    """
    print("Doing decision tree for: ")
    print(s)

    # All four split arrays must be supplied; otherwise fall back to using the
    # full data set for both training and testing.
    if any(part is None for part in (X_train, X_test, y_train, y_test)):
        X_train, X_test, y_train, y_test = X, X, y, y

    # Classification tree with the Gini split criterion. min_samples_split=100
    # stops splitting small nodes (a pre-pruning stopping rule, not post-pruning).
    dtc = tree.DecisionTreeClassifier(criterion='gini', min_samples_split=100)
    dtc = dtc.fit(X_train, y_train)

    # Export the tree graph for visualization (e.g. with the Graphviz tools).
    # Older scikit-learn returned the open file handle, which the caller had
    # to close. Newer versions close it themselves and return None, so only
    # close when a handle was actually returned.
    out = tree.export_graphviz(dtc, out_file=fileName, feature_names=attributeNames)
    close = getattr(out, "close", None)
    if close is not None:
        close()

    if len(X_test) == 0:
        # predict() rejects empty input; an empty test set has no defined error rate.
        rate = float("nan")
    else:
        # Predict the whole 2-D test matrix at once: predict() requires 2-D
        # input, so single 1-D rows are not accepted by current scikit-learn.
        rate = _classification_error_rate(y_test, dtc.predict(X_test))

    print(rate)
    print('\n')

    return rate
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号