main.py source code

python

Project: semihin    Author: HKUST-KnowComp
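import os
import pickle as pk

import numpy as np
from sklearn.naive_bayes import GaussianNB  # MultinomialNB / BernoulliNB are the commented-out alternatives


def nb_experiment(scope_name, X, y):
    """Mean Naive Bayes test accuracy over 50 stored splits, for each value in lp_cand.

    lp_cand is not defined here; it is a module-level sequence defined elsewhere in main.py.
    Returns a dict mapping each lp value to its mean accuracy.
    """
    split_dir = os.path.join('data', 'local', 'split', scope_name)
    accuracy = {}
    for lp in lp_cand:
        results = []
        for r in range(50):
            # Split files are pickled dicts whose keys are row indices into X and y
            prefix = os.path.join(split_dir, 'lb' + str(lp).zfill(3) + '_' + str(r).zfill(3))
            with open(prefix + '_train', 'rb') as f:  # pickle needs binary mode
                trainLabel = pk.load(f)
            with open(prefix + '_test', 'rb') as f:
                testLabel = pk.load(f)

            # dict.keys() cannot index an array in Python 3, so convert to lists
            train_idx = list(trainLabel.keys())
            test_idx = list(testLabel.keys())
            XTrain = X[train_idx]
            XTest = X[test_idx]
            # GaussianNB requires dense input, so densify sparse matrices
            if not isinstance(XTrain, np.ndarray):
                XTrain = XTrain.toarray()
                XTest = XTest.toarray()
            yTrain = y[train_idx]
            yTest = y[test_idx]

            # train
            #clf = MultinomialNB()
            clf = GaussianNB()
            #clf = BernoulliNB()
            clf.fit(XTrain, yTrain)

            # test: fraction of test rows predicted correctly
            pred = clf.predict(XTest)
            results.append(np.mean(pred == yTest))
        # Average over all 50 splits for this lp, then move on to the next lp
        accuracy[lp] = np.mean(results)
    return accuracy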
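The function above depends on pickled split files under data/local/split/<scope_name>/ that are not part of this snippet. The sketch below is one self-contained way to exercise the same protocol (repeated stored splits, GaussianNB, accuracy averaged over runs) on synthetic data. It is an illustration under assumptions, not the project's code: make_splits and evaluate_splits are hypothetical helpers, treating lp as the percentage of rows placed in the training set is a guess, and storing the class label as each dict value is also a guess, since the snippet only shows that the dict keys are used as row indices.

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn.naive_bayes import GaussianNB


def split_prefix(split_dir, lp, r):
    # Same naming scheme as nb_experiment: lb<lp:3 digits>_<run:3 digits>
    return os.path.join(split_dir, 'lb' + str(lp).zfill(3) + '_' + str(r).zfill(3))


def make_splits(split_dir, y, lp, n_runs, seed=0):
    """Hypothetical writer: lp is ASSUMED to be the percent of rows used for training.
    Each file is a pickled dict {row_index: label}; only the keys are read back."""
    rng = np.random.RandomState(seed)
    n = len(y)
    n_train = max(1, int(round(n * lp / 100.0)))
    for r in range(n_runs):
        perm = rng.permutation(n)
        train, test = perm[:n_train], perm[n_train:]
        prefix = split_prefix(split_dir, lp, r)
        with open(prefix + '_train', 'wb') as f:
            pickle.dump({int(i): int(y[i]) for i in train}, f)
        with open(prefix + '_test', 'wb') as f:
            pickle.dump({int(i): int(y[i]) for i in test}, f)


def evaluate_splits(split_dir, X, y, lp, n_runs):
    """Mean GaussianNB accuracy over n_runs stored splits for one lp value."""
    scores = []
    for r in range(n_runs):
        prefix = split_prefix(split_dir, lp, r)
        with open(prefix + '_train', 'rb') as f:
            train_idx = sorted(pickle.load(f))  # dict keys -> sorted list of row indices
        with open(prefix + '_test', 'rb') as f:
            test_idx = sorted(pickle.load(f))
        clf = GaussianNB().fit(X[train_idx], y[train_idx])
        scores.append(np.mean(clf.predict(X[test_idx]) == y[test_idx]))
    return float(np.mean(scores))


if __name__ == '__main__':
    # Two well-separated Gaussian blobs as stand-in data
    rng = np.random.RandomState(42)
    X = np.vstack([rng.normal(0, 1, (100, 5)), rng.normal(3, 1, (100, 5))])
    y = np.array([0] * 100 + [1] * 100)
    with tempfile.TemporaryDirectory() as d:
        for lp in (10, 50):
            make_splits(d, y, lp, n_runs=5)
            print(lp, evaluate_splits(d, X, y, lp, n_runs=5))
```

Writing the splits to disk once and reading them back, as nb_experiment does, keeps the train/test partitions fixed, so different classifiers (for example, swapping GaussianNB for MultinomialNB or BernoulliNB) are compared on identical splits.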