random_forest.py source file

python

Project: MLAB_Intuit  Author: rykard95
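from math import sqrt

from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

import utils  # project-local helper module (import path assumed)


def rf_categorize(email):
    """Return the predicted category (collection name) for an email's text.

    Note: the corpus is reloaded and the forest retrained on every call.
    """
    # get training corpus: each collection name is the label for its records
    labels = []
    data = []
    db = utils.get_local_db()
    for collection in db.collection_names():
        for record in db.get_collection(collection).find():
            labels.append(collection)
            data.append(record['Text'])

    # vectorize corpus (kept sparse; random forests accept sparse input)
    vectorizer = TfidfVectorizer()
    X = vectorizer.fit_transform(data)

    # vectorize input with the vocabulary learned from the corpus
    email_vector = vectorizer.transform([email])

    # create random forest with about sqrt(n_features) trees and return prediction
    n_trees = int(sqrt(X.shape[1])) + 1
    forest = RandomForestClassifier(n_estimators=n_trees)
    forest.fit(X, labels)
    return forest.predict(email_vector)[0]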
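
rf_categorize depends on the project's local database, so it cannot be run on its own. The following is a minimal sketch of the same pipeline (TF-IDF features, then a random forest with roughly sqrt(n_features) trees) on an in-memory corpus. The function name `categorize`, the `random_state` parameter, and the sample emails are assumptions made up for illustration; they are not part of the MLAB_Intuit project.

```python
from math import sqrt

from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer


def categorize(email, corpus, random_state=0):
    """Predict a label for `email` from (label, text) training pairs."""
    labels = [label for label, _ in corpus]
    texts = [text for _, text in corpus]

    # Learn the TF-IDF vocabulary and weights from the training texts only.
    vectorizer = TfidfVectorizer()
    X = vectorizer.fit_transform(texts)

    # Same heuristic as rf_categorize: about sqrt(vocabulary size) trees.
    n_trees = int(sqrt(X.shape[1])) + 1
    forest = RandomForestClassifier(n_estimators=n_trees, random_state=random_state)
    forest.fit(X, labels)

    # Reuse the fitted vectorizer so the input shares the training vocabulary.
    return forest.predict(vectorizer.transform([email]))[0]


# Hypothetical toy corpus, invented for this example.
SAMPLE_CORPUS = [
    ("billing", "your invoice for this month is attached"),
    ("billing", "payment received for invoice number twelve"),
    ("support", "the app crashes when I open the settings page"),
    ("support", "cannot log in after the latest update"),
]

predicted = categorize("please resend the invoice and payment details", SAMPLE_CORPUS)
print(predicted)
```

Fixing `random_state` makes repeated calls return the same label, which the original function does not guarantee. With a corpus this small the prediction itself should not be taken as meaningful.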