utils.py source code

python

Project: Vulcan    Author: rfratila
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE


def display_tsne(train_x, train_y, label_map=None):
    """
    t-distributed Stochastic Neighbor Embedding (t-SNE) visualization [1].

    [1]: van der Maaten, L., Hinton, G. (2008). Visualizing Data using t-SNE.
            Journal of Machine Learning Research 9(Nov):2579-2605.

    Args:
        train_x: 2d numpy array (batch, features) of samples
        train_y: 1d numpy array (batch,) of class labels for the samples
        label_map: optional dict mapping str(label) keys to the display
            names shown in the legend
    """
    # Embed the features in 2d; random_state=0 makes runs reproducible.
    tsne = TSNE(n_components=2, random_state=0)
    x_transform = tsne.fit_transform(train_x)
    y_unique = np.unique(train_y)
    if label_map is None:
        label_map = {str(i): str(i) for i in y_unique}
    elif not isinstance(label_map, dict):
        raise ValueError('label_map must be a dict mapping each label key'
                         ' to its true label')
    # One evenly spaced colour from the rainbow colormap per class.
    colours = plt.cm.rainbow(np.linspace(0, 1, len(y_unique)))
    plt.figure()
    for index, cl in enumerate(y_unique):
        plt.scatter(x=x_transform[train_y == cl, 0],
                    y=x_transform[train_y == cl, 1],
                    s=100,
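                    # color= (not c=) so a single RGBA row is not mistaken
                    # for per-point values to be colour-mapped.
                    color=colours[index],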
                    marker='o',
                    edgecolors='none',
                    label=label_map[str(cl)])
    plt.xlabel('X in t-SNE')
    plt.ylabel('Y in t-SNE')
    plt.legend(loc='upper right')
    plt.title('t-SNE visualization')
    # Non-blocking show; recent matplotlib only accepts block as a keyword.
    plt.show(block=False)
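display_tsne selects each class with the boolean mask `train_y == cl`, which only lines up with the rows of `x_transform` when `train_y` is a 1d vector of labels. If labels are stored one-hot as a (batch, n_classes) array, or as a (batch, 1) column, they need to be reduced first. The sketch below shows one way to do that with numpy; `to_label_vector` is a hypothetical helper, not part of Vulcan, and it assumes each one-hot row has its maximum at the true class.

```python
import numpy as np


def to_label_vector(train_y):
    """Convert labels to the 1d (batch,) shape that display_tsne masks on.

    Hypothetical helper, not part of Vulcan. Accepts:
      * a 1d array of class labels, returned unchanged,
      * a (batch, 1) column of class labels, flattened,
      * a (batch, n_classes) one-hot array, reduced with argmax.
    """
    y = np.asarray(train_y)
    if y.ndim == 1:
        return y
    if y.ndim == 2 and y.shape[1] == 1:
        # A single label column: drop the trailing axis.
        return y.ravel()
    if y.ndim == 2:
        # One-hot rows: the class index is the position of the largest entry.
        return np.argmax(y, axis=1)
    raise ValueError('expected a 1d or 2d label array, got shape %s'
                     % (y.shape,))


if __name__ == '__main__':
    one_hot = np.array([[1, 0, 0],
                        [0, 0, 1],
                        [0, 1, 0],
                        [0, 0, 1]])
    labels = to_label_vector(one_hot)
    print(labels)  # [0 2 1 2]
    # A mask like train_y == cl now selects whole rows, as in display_tsne.
    points = np.arange(8, dtype=float).reshape(4, 2)
    print(points[labels == 2, 0])  # [2. 6.]
```

With this helper in place, a call would look like `display_tsne(train_x, to_label_vector(train_y))`.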