def main(text_dir):
    """Fit LDA models over a range of topic counts and compare their coherence.

    For each topic count an ``LdaMulticore`` model (5 passes) is trained on
    the prebuilt ``simple-wiki`` Matrix Market corpus. ``build_coherence_models``
    is mapped over the models in a process pool, and the per-model
    ``num_topics``, ``u_mass``, ``c_v``, ``c_uci`` and ``c_npmi`` values it
    returns are collected into a DataFrame indexed by ``num_topics``. The table
    is written tab-separated to ``coherence_simple_wiki``, and a line plot of
    the four scores is saved via ``plt.savefig('coherence_simple_wiki')``.

    Args:
        text_dir: Currently unused. The corpus and dictionary are loaded from
            the hard-coded ``simple-wiki.mm`` / ``simple-wiki.dict`` paths
            instead of being built from this directory.
    """
    # Topic counts to evaluate: step 10 up to 100, step 20 up to 200, then
    # step 50 up to 450. ``range`` objects cannot be joined with ``+`` on
    # Python 3, so each one is materialised as a list (works on Python 2 too).
    topics = (list(range(10, 101, 10))
              + list(range(120, 201, 20))
              + list(range(250, 451, 50)))

    corpus = MmCorpus('../twitter_LDA_topic_modeling/simple-wiki.mm')
    dictionary = Dictionary.load('../twitter_LDA_topic_modeling/simple-wiki.dict')

    print('Building LDA models')
    lda_models = [models.LdaMulticore(corpus=corpus, id2word=dictionary,
                                      num_topics=i, passes=5)
                  for i in tqdm(topics)]

    print('Generating coherence models')
    # Token lists handed to the coherence computation. NOTE(review): these are
    # rebuilt from bag-of-words documents, so each distinct token appears once
    # per document and the original word order is lost; sliding-window
    # measures (c_v / c_uci / c_npmi) see an approximation of the real text.
    # Confirm this is intended.
    texts = [[dictionary[word_id] for word_id, _ in doc] for doc in corpus]
    func = partial(build_coherence_models,
                   corpus=corpus,
                   dictionary=dictionary,
                   texts=texts)
    # Leave one core free for the parent process.
    pool = multiprocessing.Pool(max(1, multiprocessing.cpu_count() - 1))
    try:
        coherence_models = pool.map(func, lda_models)
    finally:
        # close() stops new work; join() waits for the workers to exit so
        # they are reaped even if map() raised.
        pool.close()
        pool.join()

    print('Generating output data')
    columns = ['num_topics', 'u_mass', 'c_v', 'c_uci', 'c_npmi']
    # Explicit column list fixes the column order independent of pandas'
    # dict-ordering behaviour.
    rows = [[data[column] for column in columns] for data in coherence_models]
    df = pd.DataFrame(rows, columns=columns).set_index('num_topics')
    df.to_csv('coherence_simple_wiki', sep='\t')

    # Draw all four series onto a single shared axes (the first call creates
    # the figure, later calls reuse the returned axes).
    ax = None
    for column, style in (('u_mass', 'bs-'), ('c_v', 'yo-'),
                          ('c_npmi', 'r^-'), ('c_uci', 'gx-')):
        ax = df.plot(xticks=df.index, style=style, grid=True, y=column, ax=ax)
    # Legend below the plot area; the bottom margin is enlarged to fit it.
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.17), fancybox=True,
               shadow=True, ncol=4, fontsize=9)
    plt.subplots_adjust(bottom=0.2)
    plt.xticks(df.index, rotation=45, ha='right', fontsize=8)
    plt.savefig('coherence_simple_wiki')
    plt.close()
lda_tuna.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录