def gensim_classifier():
    """Train and evaluate a Word2Vec + one-vs-rest LinearSVC tweet classifier.

    Steps:
      1. Tokenise every labelled tweet on whitespace.
      2. Train a Word2Vec model on all tokenised tweets.
      3. Split the tweets with the project's ``train_test_split`` helper
         (0.80 fraction) and turn each tweet into an averaged feature vector
         via ``getAvgFeatureVecs``.
      4. Impute missing feature values, fit ``OneVsRestClassifier(LinearSVC())``
         on the training split and predict the test split.
      5. Write the predictions to ``data/w2v_linsvc.csv``, persist the fitted
         classifier to ``model/w2v_linsvc.pkl`` and pass binarised predictions,
         binarised true labels and decision scores to ``evaluate``.

    Reads the module-level ``class_list`` for label binarisation.
    Returns None; all results are side effects (files, logging, ``evaluate``).
    """
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
    label_list = get_labels()
    tweet_list = get_labelled_tweets()

    # Split every tweet into a list of whitespace-separated tokens.
    sentences = [tweet.split() for tweet in tweet_list]

    # Word2Vec hyper-parameters.
    num_features = 100    # dimensionality of the word vectors
    min_word_count = 1    # keep every token, including singletons
    num_workers = 4
    context = 2           # context window size
    downsampling = 1e-3   # sub-sampling threshold for frequent words

    # NOTE(review): Word2Vec is trained on all tweets (train + test). No labels
    # are used, but test vocabulary does shape the embedding.
    # NOTE(review): `size=` is the gensim < 4.0 keyword; gensim 4 renamed it to
    # `vector_size` -- update if the installed gensim is 4.x.
    w2v_model = Word2Vec(sentences, workers=num_workers,
                         size=num_features, min_count=min_word_count,
                         window=context, sample=downsampling, seed=1)

    # Training labels are taken as label_list[:index_value] below, so this
    # presumably returns sentences[:index_value] / sentences[index_value:] --
    # verify against train_test_split.
    index_value, train_set, test_set = train_test_split(0.80, sentences)
    train_vector = getAvgFeatureVecs(train_set, w2v_model, num_features)
    test_vector = getAvgFeatureVecs(test_set, w2v_model, num_features)

    # Fit the imputer on the training features only and reuse it for the test
    # features: both splits are then filled with the same training statistics
    # and keep the same set of columns (Imputer drops columns that are entirely
    # missing at fit time, so independent fits could yield mismatched widths).
    imputer = Imputer()
    train_vector = imputer.fit_transform(train_vector)
    test_vector = imputer.transform(test_vector)

    # Train the classifier and predict the test split.
    classifier_fitted = OneVsRestClassifier(LinearSVC()).fit(
        train_vector, label_list[:index_value])
    result = classifier_fitted.predict(test_vector)

    # Output predictions as comma-separated values.
    create_directory('data')
    result.tofile("data/w2v_linsvc.csv", sep=',')

    # Persist the *fitted* classifier. OneVsRestClassifier fits clones of the
    # estimator it wraps, so a bare LinearSVC instance would never be fitted.
    create_directory('model')
    joblib.dump(classifier_fitted, 'model/%s.pkl' % 'w2v_linsvc')

    # Evaluation: compare binarised predictions with the binarised true labels
    # of the test split, together with the per-class decision scores.
    label_score = classifier_fitted.decision_function(test_vector)
    binarise_result = label_binarize(result, classes=class_list)
    binarise_labels = label_binarize(label_list, classes=class_list)
    evaluate(binarise_result, binarise_labels[index_value:], label_score, 'w2v_linsvc')
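# Web-page residue from the page this code was copied from (not code):
# "评论列表" ("comment list"), "文章目录" ("table of contents").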