lda_tuna.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:twitter_LDA_topic_modeling 作者: kenneth-orton 项目源码 文件源码
def main(text_dir):
    """Train LDA models over a range of topic counts and plot their coherence.

    For each topic count in ``topics`` an ``LdaMulticore`` model is trained on
    the serialized Simple Wikipedia corpus. Coherence results are then
    computed in a process pool via ``build_coherence_models``. The collected
    scores are written as tab-separated text to ``coherence_simple_wiki`` and
    plotted to an image also named ``coherence_simple_wiki``. No extension is
    given, so matplotlib's default ``savefig.format`` applies.

    Args:
        text_dir: Directory of raw text documents. Currently unused: the
            corpus and dictionary are loaded from the pre-built
            ``simple-wiki.mm`` / ``simple-wiki.dict`` files instead. The
            parameter is kept so existing callers keep working.
    """
    # list() each range: Python 3 range objects cannot be joined with "+".
    topics = (list(range(10, 101, 10)) +
              list(range(120, 201, 20)) +
              list(range(250, 451, 50)))
    corpus = MmCorpus('../twitter_LDA_topic_modeling/simple-wiki.mm')
    dictionary = Dictionary.load('../twitter_LDA_topic_modeling/simple-wiki.dict')

    print('Building LDA models')
    lda_models = [models.LdaMulticore(corpus=corpus, id2word=dictionary,
                                      num_topics=i, passes=5)
                  for i in tqdm(topics)]

    print('Generating coherence models')
    # Token-list form of every document. Each (word_id, freq) pair yields the
    # word once, not freq times. Presumably consumed by text-based coherence
    # measures inside build_coherence_models -- confirm against that helper.
    texts = [[dictionary[word_id] for word_id, _freq in doc] for doc in corpus]
    func = partial(build_coherence_models,
                   corpus=corpus,
                   dictionary=dictionary,
                   texts=texts)
    # Use all but one core, but always at least one worker.
    pool = multiprocessing.Pool(max(1, multiprocessing.cpu_count() - 1))
    try:
        coherence_models = pool.map(func, lda_models)
    finally:
        # close() only stops new submissions; join() waits for the workers to
        # exit so no processes are left behind (also when map() raises).
        pool.close()
        pool.join()

    print('Generating output data')
    # Each result is read as a mapping with these keys; one column per key.
    d = defaultdict(list)
    for data in tqdm(coherence_models):
        for key in ('num_topics', 'u_mass', 'c_v', 'c_uci', 'c_npmi'):
            d[key].append(data[key])
    df = pd.DataFrame(d).set_index('num_topics')
    df.to_csv('coherence_simple_wiki', sep='\t')

    # All four measures share one axes: the first df.plot() call (ax=None)
    # creates the figure, and each later call draws onto the returned axes.
    ax = None
    for column, style in (('u_mass', 'bs-'), ('c_v', 'yo-'),
                          ('c_npmi', 'r^-'), ('c_uci', 'gx-')):
        ax = df.plot(xticks=df.index, style=style, grid=True, y=column, ax=ax)
    # Legend sits below the plot; bottom margin is enlarged to make room.
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.17), fancybox=True,
               shadow=True, ncol=4, fontsize=9)
    plt.subplots_adjust(bottom=0.2)
    plt.xticks(df.index, rotation=45, ha='right', fontsize=8)
    plt.savefig('coherence_simple_wiki')
    plt.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号