newsgroup_classifier.py source file

python
Project: nlp-playground  Author: jamesmishra
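import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

# glove_simple, WordVectorSum and print_best_worst are helpers from the
# nlp-playground project; their import paths are not shown in this excerpt.


def main():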
    """
    Train a classifier on the 20 newsgroups dataset.

    The purpose of this is mostly trying to figure out how
    to turn text into really good vector representations
    for classification... which are also hopefully good
    vector representations for unsupervised learning too.
    """
    # We don't really use our interfaces for iterating over datasets...
    # but maybe we will in the future.
    train = fetch_20newsgroups(
        subset='train',
        # categories=CATEGORIES,
        shuffle=True
    )
    test = fetch_20newsgroups(
        subset='test',
        # categories=CATEGORIES,
        shuffle=True
    )
    print("Loaded data.", len(set(train.target)), "classes.")
    glove_vectors = glove_simple()
    print("Loaded word vectors")
    pipeline = Pipeline([
        # ('vec', TfidfVectorizer()),
        ('vec', WordVectorSum(vector_dict=glove_vectors)),
        #  ('svd', TruncatedSVD()),
        ('fit', SGDClassifier())
    ])
    print("Defined pipeline. Beginning fit.")
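    # Every candidate below is commented out, so the grid is empty and
    # GridSearchCV simply cross-validates the pipeline's default parameters.
    gridsearch = GridSearchCV(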
        pipeline,
        {
            # 'vec__stop_words': ('english',),
            # 'svd__n_components': (2, 100, 500, 1000),
            # 'vec__min_df': (1, 0.01, 0.1, 0.4),
            #  'vec__max_df': (0.5, 0.75, 0.9, 1.0),
            #  'vec__max_features': (100, 1000, 10000)
        }
    )
    gridsearch.fit(train.data, train.target)
    print("Completed fit. Beginning prediction")
    predicted = gridsearch.predict(test.data)
    print("Completed prediction.")
    accuracy = np.mean(predicted == test.target)
    print("Accuracy was", accuracy)
    print("Best params", gridsearch.best_params_)
    print_best_worst(gridsearch.cv_results_)
    print(
        classification_report(
            test.target,
            predicted,
            target_names=test.target_names))
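
The pipeline's 'vec' step relies on WordVectorSum, whose source is not included in this excerpt. As a rough illustration of what a sum-of-word-vectors transformer could look like, here is a minimal sketch. The class name SummedWordVectors, the lowercase whitespace tokenization, and the toy two-dimensional vectors are all assumptions made for this example; the project's actual class may tokenize, weight, or normalize differently.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class SummedWordVectors(BaseEstimator, TransformerMixin):
    """Hypothetical stand-in for WordVectorSum (not the project's code).

    Represents each document as the sum of the vectors of its tokens.
    """

    def __init__(self, vector_dict=None):
        # Mapping from token to 1-D numpy array, e.g. loaded GloVe vectors.
        self.vector_dict = vector_dict

    def fit(self, X, y=None):
        # Nothing is learned; only record the embedding width.
        first = next(iter(self.vector_dict.values()))
        self.dim_ = len(first)
        return self

    def transform(self, X):
        out = np.zeros((len(X), self.dim_), dtype=float)
        for i, doc in enumerate(X):
            # Assumption: lowercase + whitespace split is the tokenizer.
            for token in doc.lower().split():
                vec = self.vector_dict.get(token)
                if vec is not None:  # out-of-vocabulary tokens are skipped
                    out[i] += vec
        return out


# Toy vectors for illustration only; real GloVe vectors are much wider.
toy_vectors = {
    "cat": np.array([1.0, 0.0]),
    "dog": np.array([0.0, 1.0]),
}
features = SummedWordVectors(vector_dict=toy_vectors).fit_transform(
    ["Cat dog cat", "unknown words only"]
)
print(features)
```

Because the class derives from BaseEstimator and stores its constructor argument unchanged, it can be cloned by GridSearchCV and placed in a Pipeline in the same position as the 'vec' step above. A document with no known tokens maps to the zero vector, which is one reason a real implementation might prefer averaging or TF-IDF weighting over a plain sum.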