def display_tsne(train_x, train_y, label_map=None):
    """Plot a 2-D t-SNE embedding of ``train_x`` coloured by class label.

    t-distributed Stochastic Neighbor Embedding (t-SNE) visualization [1].

    [1]: Maaten, L., Hinton, G. (2008). Visualizing Data using t-SNE.
         JMLR 9(Nov):2579--2605.

    Args:
        train_x: 2d numpy array (batch, features) of samples.
        train_y: class labels, one per row of ``train_x``. A 1d array
            (batch,) is expected; a column vector (batch, 1) is flattened.
            Multi-column (e.g. one-hot) label arrays are rejected.
        label_map: optional mapping from ``str(label)`` to the name shown in
            the legend. Labels missing from the mapping are shown as
            ``str(label)``. Defaults to showing every label as ``str(label)``.

    Raises:
        ValueError: if ``label_map`` is not a mapping, or if ``train_y`` does
            not provide exactly one label per sample.
    """
    from collections.abc import Mapping

    # Validate cheap inputs up front, before the (expensive) t-SNE fit.
    if label_map is None:
        label_map = {}
    elif not isinstance(label_map, Mapping):
        raise ValueError('label_map must be a dict of a key'
                         ' mapping to its true label')

    # The boolean-mask indexing below needs a flat array with exactly one
    # label per embedded row; flattening also accepts (batch, 1) columns.
    labels = np.asarray(train_y).reshape(-1)
    n_samples = np.shape(train_x)[0]
    if labels.shape[0] != n_samples:
        raise ValueError(
            'train_y must contain exactly one label per sample: got shape '
            f'{np.shape(train_y)} for {n_samples} samples')

    tsne = TSNE(n_components=2, random_state=0)
    x_transform = tsne.fit_transform(train_x)

    y_unique = np.unique(labels)
    colours = plt.cm.rainbow(np.linspace(0, 1, len(y_unique)))
    plt.figure()
    for colour, cl in zip(colours, y_unique):
        mask = labels == cl
        key = str(cl)
        # Use ``color=`` rather than ``c=``: a single RGBA row passed as ``c``
        # is colour-mapped as data when a class has exactly 4 points.
        plt.scatter(x=x_transform[mask, 0],
                    y=x_transform[mask, 1],
                    s=100,
                    color=colour,
                    marker='o',
                    edgecolors='none',
                    label=label_map.get(key, key))
    plt.xlabel('X in t-SNE')
    plt.ylabel('Y in t-SNE')
    plt.legend(loc='upper right')
    plt.title('t-SNE visualization')
    # Non-blocking display. Newer matplotlib backends accept ``block`` only as
    # a keyword, so the positional form ``plt.show(False)`` raises TypeError.
    plt.show(block=False)
评论列表
文章目录