classifier.py source code

python

Project: quoll  Author: LanguageMachines
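import numpy
from sklearn import naive_bayes
from sklearn.model_selection import GridSearchCV  # older scikit-learn: sklearn.grid_search

# Method of the classifier class; self.label_encoder must already be fitted on the label set.
def train_classifier(self, trainvectors, labels, alpha='default', fit_prior=True, iterations=10):
    # `iterations` is accepted for interface compatibility but unused by Naive Bayes
    encoded_labels = self.label_encoder.transform(labels)
    if alpha == '':
        # Empty alpha: grid-search 19 evenly spaced values in (0, 2] (alpha=0 is skipped)
        paramsearch = GridSearchCV(estimator=naive_bayes.MultinomialNB(),
                                   param_grid=dict(alpha=numpy.linspace(0, 2, 20)[1:]), n_jobs=6)
        paramsearch.fit(trainvectors, encoded_labels)
        selected_alpha = paramsearch.best_estimator_.alpha
    elif alpha == 'default':
        selected_alpha = 1.0  # scikit-learn's default (Laplace) smoothing
    else:
        selected_alpha = float(alpha)  # alpha may arrive as a string, e.g. '0.5'
    # Treat both the boolean False and the string 'False' (e.g. from a config file) as False
    fit_prior = fit_prior not in (False, 'False')
    self.model = naive_bayes.MultinomialNB(alpha=selected_alpha, fit_prior=fit_prior)
    self.model.fit(trainvectors, encoded_labels)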
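To try the same alpha-selection logic outside quoll, it can be sketched as a standalone function. This is a minimal sketch, not quoll's implementation: the function name `train_multinomial_nb`, the `cv` parameter, the freshly fitted `LabelEncoder` and the toy count data are all hypothetical, and it assumes a recent scikit-learn where `GridSearchCV` lives in `sklearn.model_selection`.

```python
import numpy as np
from sklearn import naive_bayes
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import LabelEncoder


def train_multinomial_nb(trainvectors, labels, alpha='default', fit_prior=True, cv=3):
    """Hypothetical standalone version of train_classifier's logic."""
    encoder = LabelEncoder()
    y = encoder.fit_transform(labels)
    if alpha == '':
        # Search alpha over (0, 2]; cv=3 because the toy data set below is tiny
        grid = {'alpha': np.linspace(0, 2, 20)[1:]}
        search = GridSearchCV(naive_bayes.MultinomialNB(), param_grid=grid, cv=cv)
        search.fit(trainvectors, y)
        selected_alpha = search.best_params_['alpha']
    elif alpha == 'default':
        selected_alpha = 1.0
    else:
        selected_alpha = float(alpha)
    # Accept both the boolean False and the string 'False'
    fit_prior = fit_prior not in (False, 'False')
    model = naive_bayes.MultinomialNB(alpha=selected_alpha, fit_prior=fit_prior)
    model.fit(trainvectors, y)
    return model, encoder


# Hypothetical toy term counts: 'spam' rows favour feature 0, 'ham' rows features 1 and 2
X = np.array([
    [3, 0, 1], [4, 1, 0], [5, 0, 0], [3, 1, 1], [4, 0, 2], [6, 1, 0],
    [0, 3, 4], [1, 4, 3], [0, 5, 2], [1, 3, 5], [0, 4, 4], [2, 5, 3],
])
y = ['spam'] * 6 + ['ham'] * 6

model, enc = train_multinomial_nb(X, y, alpha='')
pred = enc.inverse_transform(model.predict([[5, 0, 1], [0, 4, 3]]))
print(model.alpha, list(pred))
```

Returning the encoder alongside the model keeps predictions mappable back to the original string labels, which is the role `self.label_encoder` plays in the method above.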