def plot_heatmaps(data, labels, alpha, mis, column_label, cont, topk=20, prefix='', focus=''):
    """Save one standardized heatmap image per latent factor.

    For every row ``j`` of ``mis``, the columns with ``alpha[j] > 0`` and
    ``mis[j] > 0`` are ranked by ``alpha[j] * mis[j]`` (largest first) and the
    first ``topk`` are kept.  The matching columns of ``data`` are reordered
    along the first axis by ascending ``cont[:, j]``, z-scored per column
    (NaN-aware), and drawn with ``sns.heatmap`` clipped to ``[-3, 3]``.
    Factors that end up with fewer than two selected columns are skipped.

    Args:
        data: 2-D array indexed as ``data[:, columns]`` (presumably
            samples x variables -- confirm against the caller).  Integer or
            boolean input is converted to float before standardizing; the
            caller's array is never modified.
        labels: Not used by the active code; only the disabled ``plot_rels``
            call at the end of the loop references it.  Kept so the
            signature stays unchanged.
        alpha: 2-D array indexed as ``alpha[j, columns]``.
        mis: 2-D array; ``mis.shape[0]`` is the number of factors plotted.
        column_label: Sequence of column names.  It must support ``.index``
            (e.g. a list) when ``focus`` is found in it.
        cont: 2-D array; column ``j`` supplies the ordering of the heatmap's
            x-axis for factor ``j``.
        topk: Maximum number of ranked columns kept per factor.
        prefix: Output location; figures are written to
            ``'{prefix}/heatmaps/group_num={j}.png'``.
            NOTE(review): with the default ``''`` this path starts with
            ``/`` -- callers presumably always pass a prefix; confirm.
        focus: Column name that is forced to be the first heatmap row when
            it occurs in ``column_label`` and was not already selected.  It
            is inserted after the ``topk`` cut, so a factor may show
            ``topk + 1`` rows.

    Returns:
        None.  The only effects are the created directory and image files.
    """
    cmap = sns.cubehelix_palette(as_cmap=True, light=.9)
    n_factors = mis.shape[0]

    # The position of the focus column does not depend on the factor, so it
    # is resolved once instead of on every iteration.
    focus_index = column_label.index(focus) if focus in column_label else None

    for j in range(n_factors):
        # Columns that contribute to factor j, strongest contribution first.
        inds = np.where(np.logical_and(alpha[j] > 0, mis[j] > 0.))[0]
        inds = inds[np.argsort(- alpha[j, inds] * mis[j, inds])][:topk]
        if focus_index is not None and focus_index not in inds:
            inds = np.insert(inds, 0, focus_index)

        # A heatmap with fewer than two rows is not drawn.
        if len(inds) < 2:
            continue

        plt.clf()
        order = np.argsort(cont[:, j])
        # Indexing with an index array copies, so the in-place arithmetic
        # below never touches the caller's ``data``.
        subdata = data[:, inds][order].T
        # In-place subtraction/division of float statistics into an integer
        # or boolean array raises a casting error, so promote those first.
        # Float (and complex) inputs keep their original dtype.
        if not np.issubdtype(subdata.dtype, np.inexact):
            subdata = subdata.astype(float)
        subdata -= np.nanmean(subdata, axis=1, keepdims=True)
        subdata /= np.nanstd(subdata, axis=1, keepdims=True)

        columns = [column_label[i] for i in inds]
        sns.heatmap(subdata, vmin=-3, vmax=3, cmap=cmap, yticklabels=columns,
                    xticklabels=False, mask=np.isnan(subdata))

        filename = '{}/heatmaps/group_num={}.png'.format(prefix, j)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        plt.title("Latent factor {}".format(j))
        plt.savefig(filename, bbox_inches='tight')
        plt.close('all')
        # Disabled companion plot (the only user of ``labels``):
        #plot_rels(data[:, inds], list(map(lambda q: column_label[q], inds)), colors=cont[:, j],
        #          outfile=prefix + '/relationships/group_num=' + str(j), latent=labels[:, j], alpha=0.1)
# (Web-page navigation residue, not part of the module: "Comment list" / "Table of contents".)