def rf_categorize(email, *, random_state=None):
    """Predict which local-DB collection an email most likely belongs to.

    Every collection in the local database is used as a class label and the
    ``'Text'`` field of every document in it as a training example. A TF-IDF
    vectorizer and a random forest are fitted on that corpus on each call,
    and the forest's prediction for *email* is returned.

    Args:
        email: Raw text of the email to classify.
        random_state: Optional seed forwarded to ``RandomForestClassifier``
            so that predictions are reproducible. ``None`` (the default)
            keeps the original, non-seeded behaviour.

    Returns:
        The predicted label, i.e. the name of one of the collections.

    Raises:
        ValueError: If the database holds no documents to train on.
        KeyError: If a stored document has no ``'Text'`` field.
    """
    # get training corpus: one (label, text) pair per stored document
    db = utils.get_local_db()
    # NOTE(review): db is presumably a PyMongo Database (get_collection/find).
    # collection_names() was deprecated in PyMongo 3.7 and removed in 4.0.
    # Check the class, not the instance: Database.__getattr__ would return a
    # Collection for an unknown attribute name instead of failing.
    if hasattr(type(db), 'list_collection_names'):
        collection_names = db.list_collection_names()
    else:  # older drivers without list_collection_names()
        collection_names = db.collection_names()

    labels = []
    data = []
    for collection in collection_names:
        for record in db.get_collection(collection).find():
            labels.append(collection)
            data.append(record['Text'])
    if not data:
        raise ValueError('no training emails found in the local database')

    # vectorize corpus; keep the TF-IDF matrix sparse -- densifying it costs
    # documents x vocabulary memory, and RandomForestClassifier accepts
    # sparse input directly.
    vectorizer = TfidfVectorizer()
    X = vectorizer.fit_transform(data)
    n_features = X.shape[1]

    # vectorize input with the vocabulary learned from the corpus
    email_vector = vectorizer.transform([email])

    # create random forest and return prediction.
    # NOTE(review): the forest size sqrt(vocabulary size) + 1 is kept from
    # the original; sqrt(n_features) is more usually a max_features value,
    # so confirm this heuristic is intended.
    # NOTE(review): the whole model is rebuilt on every call; consider
    # caching it if this function is called frequently.
    forest = RandomForestClassifier(
        n_estimators=int(sqrt(n_features)) + 1,
        random_state=random_state,
    )
    forest.fit(X, labels)
    return forest.predict(email_vector)[0]
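# 评论列表 ("Comment list") -- residue from the web page this code was copied from; not code
# 文章目录 ("Table of contents") -- residue from the web page this code was copied from; not code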