newsgroups_clustering.py source code

python

Project: nlp-playground  Author: jamesmishra
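from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

# CATEGORIES, ClusteringWithSupervision and print_best_worst are defined
# elsewhere in the nlp-playground project and are not shown here.


def main():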
    """
    Cluster the newsgroups dataset and measure against labels.

    In this script, we run a grid search over several TF-IDF
    representations of the newsgroups dataset, looking for the
    representation that clusters best without supervision.

    We're measuring the quality of that unsupervised
    representation by how well it matches up to the actual
    supervised labels of the newsgroups dataset.
    """
    newsgroups = fetch_20newsgroups(
        subset='train',
        categories=CATEGORIES,
        shuffle=True
    )
    print("Loaded data")
    gridsearch = GridSearchCV(
        Pipeline([
            ('vec', TfidfVectorizer()),
            ('cluster', ClusteringWithSupervision(
                cluster_instance=MiniBatchKMeans()))
        ]),
        {
            'vec__stop_words': (None, 'english')
        }
    )
    print("Defined pipeline. Beginning fit.")
    gridsearch.fit(newsgroups.data, newsgroups.target)
    print_best_worst(gridsearch.cv_results_)
    best_estimator = gridsearch.best_estimator_
    predicted = best_estimator.predict(newsgroups.data)
    print(
        classification_report(
            newsgroups.target,
            predicted,
            target_names=newsgroups.target_names))
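The helpers ClusteringWithSupervision and print_best_worst come from elsewhere in the nlp-playground project and are not shown, so the sketch below is an assumption about how such a wrapper might work, not the project's code. The hypothetical MajorityLabelClusterer clusters documents without looking at labels, maps each cluster to the most common true label among its members, and scores that mapping by accuracy. That mapping is what lets GridSearchCV and classification_report treat a clusterer like a classifier. The sketch uses a tiny made-up corpus instead of fetch_20newsgroups (which needs a download) and KMeans instead of MiniBatchKMeans for more stable results on so few documents.

```python
import numpy as np
from sklearn.base import BaseEstimator, clone
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline


class MajorityLabelClusterer(BaseEstimator):
    """Hypothetical stand-in for the project's ClusteringWithSupervision.

    Clusters without using labels, then maps each cluster to the most
    common true label among its members so it can be scored like a
    classifier.
    """

    def __init__(self, cluster_instance=None):
        self.cluster_instance = cluster_instance

    def fit(self, X, y):
        base = self.cluster_instance
        if base is None:
            base = KMeans(n_clusters=2, n_init=10, random_state=0)
        self.cluster_ = clone(base)
        # Labels are NOT used here: the clustering itself is unsupervised.
        cluster_ids = self.cluster_.fit_predict(X)
        y = np.asarray(y)
        # Overall majority label, used for clusters with no training points.
        values, counts = np.unique(y, return_counts=True)
        self.fallback_ = values[np.argmax(counts)]
        # Map each cluster id to the majority label of its members.
        self.mapping_ = {}
        for cid in np.unique(cluster_ids):
            values, counts = np.unique(y[cluster_ids == cid], return_counts=True)
            self.mapping_[int(cid)] = values[np.argmax(counts)]
        return self

    def predict(self, X):
        cluster_ids = self.cluster_.predict(X)
        return np.array([self.mapping_.get(int(c), self.fallback_) for c in cluster_ids])

    def score(self, X, y):
        # GridSearchCV maximizes this: agreement between clusters and labels.
        return accuracy_score(y, self.predict(X))


# Toy corpus with interleaved labels (0 = sports, 1 = cooking).
docs = [
    "the team won the football match",
    "bake the bread in a hot oven",
    "the striker scored a late goal",
    "stir the soup and add salt",
    "the coach praised the team defense",
    "chop onions and fry them in butter",
    "fans cheered as the team scored",
    "simmer the sauce and season with pepper",
]
labels = [0, 1, 0, 1, 0, 1, 0, 1]

search = GridSearchCV(
    Pipeline([
        ("vec", TfidfVectorizer()),
        ("cluster", MajorityLabelClusterer()),
    ]),
    {"vec__stop_words": (None, "english")},
    cv=2,
)
search.fit(docs, labels)
predicted = search.best_estimator_.predict(docs)
print(search.best_params_, accuracy_score(labels, predicted))
```

Majority-label mapping is only one option. Permutation-invariant metrics such as sklearn.metrics.adjusted_rand_score also measure how well clusters agree with labels, but they do not produce per-class predictions that classification_report can summarize.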