classifier_maker.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:data-analysis 作者: ymohanty 项目源码 文件源码
def main(argv):
    """Train a classifier, report confusion matrices, and save test predictions.

    Args:
        argv: Command-line style argument sequence laid out as
            ``[program, classification method, train data file,
            test data file, train categories file, test categories file]``.
            The two category files are optional and independent: each one is
            forwarded to its reader only when it is actually present.

    Side effects:
        Prints progress messages and the confusion matrices for the training
        and testing data, adds a "class" column holding the predicted test
        categories to the test data object, and writes that object to
        ``<test data path without extension>-<method>-classified``.

    Raises:
        SystemExit: with code -1 (after printing a usage message) when fewer
            than four arguments are supplied.
    """
    # Imported locally so the function does not depend on file-level imports.
    import os
    import sys

    if len(argv) < 4:
        print('Usage: python %s <classification method> <train data file> <test data file> <optional train categories> <optional test categories>' % (argv[0]))
        sys.exit(-1)

    print("Reading data...")

    # Slicing yields an empty list when an optional category file is absent,
    # so supplying only the train categories no longer indexes past the end
    # of argv.
    train_args = [argv[2]] + list(argv[4:5])
    test_args = [argv[3]] + list(argv[5:6])
    training_data, training_labels, _ = read_training_data(*train_args)
    testing_data, testing_labels, dObj_test = read_testing_data(*test_args)

    print("Building the Classifier...")

    classifier = build_classifier(training_data, training_labels, argv[1])

    print("Classifying test and training data...")

    ctraincats, _ = classifier.classify(training_data)
    ctestcats, _ = classifier.classify(testing_data)

    # Recast labels to the range [0, C-1]: return_inverse gives, for every
    # label, its index in the sorted array of unique labels.
    # NOTE(review): assumes the label objects are N x 1 matrices, so that
    # .T.tolist()[0] is the flat list of labels -- confirm against the readers.
    _, mapping1 = np.unique(training_labels.T.tolist()[0], return_inverse=True)
    _, mapping2 = np.unique(testing_labels.T.tolist()[0], return_inverse=True)

    # Back to column form to match the layout of the predicted categories.
    mapping1 = np.matrix(mapping1).T
    mapping2 = np.matrix(mapping2).T

    print("Constructing the Confusion matrices")

    cmtx_train = classifier.confusion_matrix(mapping1, ctraincats)
    cmtx_test = classifier.confusion_matrix(mapping2, ctestcats)

    print(cmtx_train)
    print(cmtx_test)

    print("\nTraining Data")
    print(classifier.confusion_matrix_str(cmtx_train))
    print("\nTesting Data")
    print(classifier.confusion_matrix_str(cmtx_test))

    print("Writing to file")

    dObj_test.add_column("class", "numeric", ctestcats.T.tolist()[0])

    # splitext strips only the final extension, so dots elsewhere in the path
    # (e.g. "./data/test.csv") no longer truncate the output name to "".
    out_base = os.path.splitext(argv[3])[0]
    dObj_test.write_to_file(out_base + "-" + argv[1] + "-classified", dObj_test.get_headers())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号