vis.py source code

python

Project: facenet_pytorch    Author: liorshk
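import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects
import seaborn as sns


def visual_feature_space(features, labels, num_classes, name_dict):
    """Scatter-plot 2-D features coloured by class, with each class name drawn at its median point."""
    labels = np.asarray(labels)  # boolean masking and .astype below need an ndarray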

    title_font = {'fontname': 'Arial', 'size': 20, 'color': 'black', 'weight': 'normal',
                  'verticalalignment': 'bottom'}  # Bottom vertical alignment for more space
    axis_font = {'fontname': 'Arial', 'size': 20}

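    # One distinct colour per class from seaborn's evenly spaced HLS palette.
    palette = np.array(sns.color_palette("hls", num_classes))

    # Scatter plot of the first two feature dimensions, coloured by class.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    sc = ax.scatter(features[:, 0], features[:, 1], lw=0, s=40,
                    c=palette[labels.astype(int)])  # np.int was removed in NumPy 1.24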
    # ax.axis('off')
    # ax.axis('tight')

    # Add a text label for each class, placed at the median of its points.
    txts = []
    for i in range(num_classes):
        # Position of each label.
        xtext, ytext = np.median(features[labels == i, :], axis=0)
        txt = ax.text(xtext, ytext, name_dict[i])
        txt.set_path_effects([
            PathEffects.Stroke(linewidth=5, foreground="w"),
            PathEffects.Normal()])
        txts.append(txt)
    ax.set_xlabel('Activation of the 1st neuron', **axis_font)
    ax.set_ylabel('Activation of the 2nd neuron', **axis_font)
    ax.set_title('softmax_loss + center_loss', **title_font)
    ax.set_facecolor('grey')  # set_axis_bgcolor was deprecated in Matplotlib 2.0 and has since been removed
    f.savefig('center_loss.png')
    plt.show()
    return f, ax, sc, txts
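The two non-plotting steps in `visual_feature_space` are mapping each integer label to a palette row (`palette[labels.astype(int)]`) and placing each class name at the coordinate-wise median of that class's points. A minimal NumPy-only sketch of those two steps follows so they can be checked without a display. The helper names `class_label_positions` and `point_colors` are hypothetical and not part of facenet_pytorch. The sketch also assumes, unlike the original loop, that a class with no samples should yield NaN rather than a NumPy empty-slice warning, and that only the first two feature columns matter (the same columns the scatter plot uses).

```python
import numpy as np


def class_label_positions(features, labels, num_classes):
    """Hypothetical helper: per-class median of the first two feature columns.

    Returns a (num_classes, 2) array. Rows for classes with no samples are NaN
    (an assumption; the original code would call np.median on an empty slice).
    """
    features = np.asarray(features, dtype=float)
    labels = np.asarray(labels).astype(int)
    positions = np.full((num_classes, 2), np.nan)
    for i in range(num_classes):
        mask = labels == i
        if mask.any():
            # Coordinate-wise median, as in the label-placement loop above.
            positions[i] = np.median(features[mask, :2], axis=0)
    return positions


def point_colors(labels, palette):
    """Hypothetical helper: map integer labels to palette rows via fancy indexing."""
    return np.asarray(palette)[np.asarray(labels).astype(int)]


if __name__ == "__main__":
    feats = np.array([[0.0, 0.0], [2.0, 2.0], [4.0, 1.0], [10.0, 10.0], [12.0, 8.0]])
    labs = np.array([0, 0, 0, 1, 1])
    # Class 0 lands at (2, 1), class 1 at (11, 9); class 2 has no points, so NaN.
    print(class_label_positions(feats, labs, num_classes=3))
    print(point_colors(labs, [(1, 0, 0), (0, 1, 0), (0, 0, 1)]))
```

Using the median rather than the mean keeps each label near the bulk of its cluster even when a few embeddings are far-off outliers.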