def DBN_plot_tsne(doc_codes, doc_labels, classes_to_visual, save_file):
    """Embed document codes in 2-D with t-SNE and scatter-plot them by class.

    The figure is written to ``save_file`` and then displayed with
    ``plt.show()``.

    Args:
        doc_codes: Iterable of per-document code vectors; it is materialised
            with ``list(...)`` and stacked into one array that is handed to
            ``TSNE.fit_transform``.
        doc_labels: Sequence of labels aligned with ``doc_codes``; each entry
            is compared for equality against the keys of
            ``classes_to_visual``.
        classes_to_visual: Mapping from label value to the text shown for
            that class in the legend. Only classes present in this mapping
            are drawn.
        save_file: Destination passed unchanged to ``plt.savefig``.

    Returns:
        None.

    Side effects:
        Calls ``np.set_printoptions(suppress=True)``, which changes NumPy's
        process-wide print options.
    """
    markers = ["o", "v", "8", "s", "p", "*", "h", "H", "+", "x", "D"]

    X = np.r_[list(doc_codes)]
    # NOTE(review): newer scikit-learn releases appear to rename `n_iter` to
    # `max_iter` and to require perplexity < n_samples -- confirm against the
    # installed version before changing these arguments.
    tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
    np.set_printoptions(suppress=True)
    X = tsne.fit_transform(X)

    # Convert the labels once; the per-class masks below reuse this array.
    labels = np.array(doc_labels)

    plt.figure(figsize=(10, 10), facecolor='white')
    for class_idx, (c, class_name) in enumerate(classes_to_visual.items()):
        idx = labels == c
        # Wrap around the marker list when there are more classes than
        # distinct marker shapes.
        plt.plot(X[idx, 0], X[idx, 1], linestyle='None', alpha=0.6,
                 marker=markers[class_idx % len(markers)],
                 markersize=6, label=class_name)
    plt.legend(loc='upper center', shadow=True)
    plt.title("tsne")
    plt.savefig(save_file)
    plt.show()
评论列表
文章目录