def reuters_visualize_tsne(doc_codes, doc_labels, classes_to_visual, save_file):
    """
    Embed document codes in 2D with t-SNE and draw one scatter series per class.

    Only documents that carry at least one label found among the keys of
    classes_to_visual are embedded and plotted. The figure is written to
    save_file and then shown.

    @param doc_codes: Mapping from document id to that document's code vector
        (presumably equal-length numeric vectors -- confirm against caller).
    @param doc_labels: Mapping from document id to an iterable of its labels.
    @param classes_to_visual: Mapping from class name to the legend text
        displayed for that class.
    @param save_file: Destination passed to plt.savefig.
    @raise ValueError: If no document carries any of the requested classes.
    """
    markers = ["o", "v", "8", "s", "p", "*", "h", "H", "+", "x", "D"]
    wanted = set(classes_to_visual)

    # Keep only the documents tagged with at least one requested class.
    selected = [
        (code, doc_labels[doc])
        for doc, code in doc_codes.items()
        if not wanted.isdisjoint(doc_labels[doc])
    ]
    if not selected:
        # Unpacking an empty zip() would otherwise fail with a cryptic
        # "not enough values to unpack" message.
        raise ValueError(
            "No documents carry any of the requested classes; nothing to plot."
        )
    codes, labels = zip(*selected)

    # Stack the per-document codes into one (n_documents, code_size) array.
    X = np.asarray(codes)
    # NOTE(review): recent scikit-learn releases rename n_iter to max_iter --
    # confirm against the installed version before changing this call.
    tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
    # NOTE(review): this alters NumPy's global print options; kept because
    # callers may rely on the side effect.
    np.set_printoptions(suppress=True)
    X = tsne.fit_transform(X)

    plt.figure(figsize=(10, 10), facecolor='white')
    for position, (class_name, legend_text) in enumerate(classes_to_visual.items()):
        # get_indices is defined elsewhere; presumably it returns the positions
        # in labels whose label collection contains class_name -- TODO confirm.
        idx = get_indices(labels, class_name)
        # Wrap around the marker list when there are more classes than markers.
        plt.plot(X[idx, 0], X[idx, 1], linestyle='None', alpha=0.6,
                 marker=markers[position % len(markers)],
                 markersize=6, label=legend_text)
    plt.legend(loc='upper center', shadow=True)
    plt.title("tsne")
    plt.savefig(save_file)
    plt.show()
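# Comment list (web-page navigation residue, not part of the code)
# Table of contents (web-page navigation residue, not part of the code)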