def print_2D(points, label, id_map, palette=None, markers=None):
    """Scatter-plot 2-D points, one colour/marker group per label.

    Parameters
    ----------
    points : array-like, shape (N_samples, 2)
        Point coordinates; columns 0 and 1 are used as x and y.
    label : array-like of int, shape (N_samples,)
        Integer label of each point. Each distinct value is drawn as its
        own scatter group and is used to index ``palette``, ``markers``
        and ``id_map``.
    id_map : mapping
        Maps a label id to the name shown in the legend.
    palette : sequence of colours, optional
        Colour per label id. Defaults to a module-level ``current_palette``
        when one is defined, otherwise ``sns.color_palette("RdBu_r",
        max(label) + 1)``.
    markers : sequence of marker styles, optional
        Marker per label id. Defaults to a module-level ``markers_keys``
        when one is defined, otherwise a small built-in set of markers.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    # Accept plain lists too: the boolean masking below (label == lab) only
    # works element-wise on ndarrays.
    points = np.asarray(points)
    label = np.asarray(label)
    n_cell, _ = points.shape

    if palette is None:
        # Keep using a module-level palette if the surrounding script defines
        # one; otherwise build the palette the original (commented-out) code used.
        try:
            palette = current_palette
        except NameError:
            n_colors = int(label.max()) + 1 if label.size else 1
            palette = sns.color_palette("RdBu_r", n_colors)
    if markers is None:
        try:
            markers = markers_keys
        except NameError:
            markers = ("o", "s", "^", "v", "D", "<", ">", "p", "*", "h")

    # Smaller markers keep large point clouds readable.
    size = 10 if n_cell > 500 else 20

    fig = plt.figure()
    ax = fig.add_subplot(111)
    for lab in np.unique(label):
        mask = label == lab
        ax.scatter(
            points[mask, 0],
            points[mask, 1],
            # `color=` rather than `c=`: a single RGB(A) tuple passed as `c`
            # is value-mapped through a colormap when the group's point count
            # equals the tuple's length.
            color=palette[lab % len(palette)],
            label=id_map[lab],
            s=size,
            # Wrap around so more labels than colours/markers do not raise.
            marker=markers[lab % len(markers)],
        )

    # Shrink the axes vertically to make room for the legend placed below them.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 + box.height * 0.1,
                     box.width, box.height * 0.9])
    ax.legend(scatterpoints=1, loc="upper center",
              bbox_to_anchor=(0.5, -0.08), ncol=6,
              fancybox=True, prop={"size": 8})
    sns.despine()
    return fig
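# 评论列表 ("comment list") — leftover blog-page navigation text, not code.
# 文章目录 ("table of contents") — leftover blog-page navigation text, not code.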