def get_local_words(preds, vocab, NEs=None, k=50):
    """Return up to ``k`` vocabulary words ranked by ascending entropy.

    ``preds`` is L1-normalized along axis 0, an entropy value is computed
    from the normalized matrix, and ``vocab`` is reordered so that the
    entries with the smallest entropy come first. Words listed in ``NEs``
    are then dropped and the first ``k`` remaining words are returned.

    NOTE(review): ``normalize`` and ``stats`` are not defined in this block;
    they look like ``sklearn.preprocessing.normalize`` and ``scipy.stats`` --
    confirm against the file's imports. Under that assumption one entropy
    value is produced per column of ``preds``, so ``preds`` presumably has
    one column per entry of ``vocab`` -- verify against the caller.

    Args:
        preds: 2-D score/probability matrix accepted by ``normalize``.
        vocab: Sequence of words, indexed in parallel with the entropy
            values computed from ``preds``.
        NEs: Optional iterable of words to exclude from the result.
            ``None`` (the default) excludes nothing.
        k: Maximum number of words to return.

    Returns:
        A list containing the first ``k`` non-excluded words, ordered from
        lowest to highest entropy.
    """
    # Normalize the probabilities of each vocabulary word (L1, along axis 0).
    normalized_preds = normalize(preds, norm='l1', axis=0)
    entropies = stats.entropy(normalized_preds)

    # argsort is ascending, so the lowest-entropy words end up first.
    sorted_indices = np.argsort(entropies)
    sorted_local_words = np.array(vocab)[sorted_indices].tolist()

    # Build the exclusion set once; a None sentinel replaces the former
    # mutable default argument.
    excluded = set(NEs) if NEs is not None else set()
    filtered_local_words = [
        word for word in sorted_local_words if word not in excluded
    ]
    return filtered_local_words[:k]
评论列表
文章目录