def plot_tsne_3d(doc_codes, doc_labels, classes_to_visual, save_file, maker_size=None, opaque=None):
    """Embed document codes in 3-D with t-SNE and save a per-class scatter plot.

    Args:
        doc_codes: Either a dict mapping document id -> code vector, or a
            sequence of code vectors aligned with ``doc_labels``.
        doc_labels: A dict mapping document id -> class label (used when
            ``doc_codes`` is also a dict), otherwise a sequence of labels
            aligned with ``doc_codes``.
        classes_to_visual: Sequence of class labels to draw. Each class gets
            its own colour, marker and legend entry, in this order. Colours
            and markers are reused cyclically when there are more classes
            than styles.
        save_file: Destination passed to ``Figure.savefig``.
        maker_size: Optional per-class marker sizes, indexed like
            ``classes_to_visual`` (list or array). Each class defaults to 20
            when this is None or empty.
        opaque: Optional per-class alpha values, indexed like
            ``classes_to_visual`` (list or array). Each class defaults to 1
            when this is None or empty.

    Raises:
        ValueError: If the dict inputs contain no document whose label is in
            ``classes_to_visual``.

    Side effects:
        Sets the global matplotlib legend font size to 20 and enables NumPy's
        ``suppress`` print option. Both persist after the call. Also calls
        ``plt.show()``.
    """
    markers = ["D", "p", "*", "s", "d", "8", "^", "H", "v", ">", "<", "h", "|"]
    colors = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
    # NOTE: global rcParams change; kept because callers may rely on it.
    plt.rc('legend', **{'fontsize': 20})

    if isinstance(doc_codes, dict) and isinstance(doc_labels, dict):
        pairs = [(code, doc_labels[doc]) for doc, code in doc_codes.items()
                 if doc_labels[doc] in classes_to_visual]
        if not pairs:
            # zip(*[]) would otherwise fail with an unhelpful unpacking error.
            raise ValueError("no document has a label in classes_to_visual")
        codes, labels = zip(*pairs)
    else:
        # NOTE: unlike the dict branch, sequence inputs are NOT filtered, so
        # documents of classes that are not drawn still shape the embedding.
        codes, labels = doc_codes, doc_labels

    X = np.asarray(list(codes))
    # NOTE: global NumPy side effect, kept for backward compatibility.
    np.set_printoptions(suppress=True)
    X = _make_tsne_3d().fit_transform(X)

    fig = plt.figure(figsize=(10, 10), facecolor='white')
    ax = fig.add_subplot(111, projection='3d')
    label_arr = np.asarray(labels)  # built once; compared per class below

    # The legend is built from 2-D Line2D "proxy" artists with the same
    # colour/marker, instead of the 3-D scatter artists themselves.
    scatter_proxy = []
    for i, cls in enumerate(classes_to_visual):
        # Cycle the style lists when there are more classes than styles.
        color = colors[i % len(colors)]
        marker = markers[i % len(markers)]
        idx = label_arr == cls
        ax.scatter(X[idx, 0], X[idx, 1], X[idx, 2], c=color,
                   alpha=_per_class_value(opaque, i, 1),
                   s=_per_class_value(maker_size, i, 20),
                   marker=marker, label=cls)
        scatter_proxy.append(mpl.lines.Line2D([0], [0], linestyle="none",
                                              c=color, marker=marker, label=cls))
    ax.legend(scatter_proxy, classes_to_visual, numpoints=1)
    fig.savefig(save_file)
    plt.show()


def _per_class_value(values, i, default):
    """Return ``values[i]``, or *default* when *values* is None or empty.

    Uses explicit None/length checks instead of truthiness. NumPy arrays
    (whose truth value is ambiguous) are therefore accepted as well as lists.
    """
    if values is None or len(values) == 0:
        return default
    return values[i]


def _make_tsne_3d():
    """Build the 3-D t-SNE estimator used by ``plot_tsne_3d``.

    ``TSNE`` is presumably ``sklearn.manifold.TSNE``; its import is outside
    this block, so verify this. scikit-learn renamed ``n_iter`` to
    ``max_iter`` in 1.5 and removed ``n_iter`` in 1.7. Try the new keyword
    first and fall back to the old one for older releases.
    """
    params = dict(perplexity=30, n_components=3, init='pca')
    try:
        return TSNE(max_iter=5000, **params)
    except TypeError:
        return TSNE(n_iter=5000, **params)
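# 评论列表 — web-page residue ("comment list"), not code
# 文章目录 — web-page residue ("table of contents"), not code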