def visualize_frequent_words(vectors_2d: np.ndarray, dataset: DataSet, k: int, ax: plt.Axes = None) -> None:
    """Scatter-plot and label the 2-D vectors of the ``k`` most frequent words.

    Frequencies are counted over the distinct values of ``dataset.data``.
    Each selected value is used both as a row index into ``vectors_2d`` and
    as the argument to ``dataset.vocabulary.to_word`` to obtain its label.

    Args:
        vectors_2d: Array indexed by word id on its first axis; columns 0
            and 1 are used as the x and y coordinates.
        dataset: Object exposing ``data`` (the values that are counted) and
            ``vocabulary.to_word`` (maps a value to its label text).
        k: Number of most frequent words to draw. If it is greater than or
            equal to the number of distinct words, every word is drawn.
        ax: Axes to draw on. When omitted, a new figure with
            ``figsize=(13, 13)`` is created, tightly laid out and shown.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError("k must be non-negative, got %r" % (k,))

    word_ids, counts = np.unique(dataset.data, return_counts=True)
    if k < len(counts):
        # argpartition puts the k largest counts first but does not sort
        # them, so the selected words are not ordered by frequency.
        indices = np.argpartition(-counts, k)[:k]
    else:
        # argpartition rejects kth >= len(counts); in that case every
        # distinct word belongs to the "top k" anyway.
        indices = np.arange(len(counts))
    frequent_word_ids = word_ids[indices]

    if ax is None:
        fig, ax = plt.subplots(figsize=(13, 13))
    else:
        # The caller owns the figure; do not lay it out or show it here.
        fig = None

    points = vectors_2d[frequent_word_ids]
    ax.scatter(points[:, 0], points[:, 1], s=2, alpha=0.25)
    for word_id, point in zip(frequent_word_ids, points):
        ax.annotate(dataset.vocabulary.to_word(word_id), (point[0], point[1]))

    if fig is not None:
        fig.tight_layout()
        fig.show()
评论列表
文章目录