def plot_corrmat(in_csv, out_file=None):
    """Plot a hierarchically clustered correlation matrix of a CSV table.

    The table is read with pandas and the identifier columns ``subject_id``,
    ``site`` and ``modality`` are discarded when present. The pairwise
    correlation of the remaining numeric columns is then drawn with
    ``seaborn.clustermap`` and saved to disk.

    Parameters
    ----------
    in_csv : str
        Path to the input CSV file. Cells containing ``n/a`` (in addition to
        the pandas default missing-value markers) are read as NaN.
    out_file : str, optional
        Output image path; defaults to ``'corr_matrix.svg'``. Only ``.pdf``,
        ``.svg`` and ``.png`` extensions (case-insensitive) are honoured.
        Any other extension is replaced by ``.svg``.

    Returns
    -------
    The grid object returned by ``seaborn.clustermap`` (after saving it).

    Raises
    ------
    ValueError
        If fewer than two numeric columns have a defined correlation, which
        is the minimum needed to cluster the matrix.
    """
    import seaborn as sn

    sn.set(style="whitegrid")

    # ``na_filter`` must stay enabled: with ``na_filter=False`` pandas ignores
    # ``na_values`` and columns containing 'n/a' are parsed as strings.
    dataframe = pd.read_csv(in_csv, index_col=False, na_values='n/a')

    # Identifier / categorical columns do not take part in the correlation.
    excluded = {'subject_id', 'site', 'modality'}
    colnames = [col for col in dataframe.columns if col not in excluded]

    # Select numeric columns explicitly: pandas >= 2.0 raises on non-numeric
    # data in ``DataFrame.corr`` instead of silently dropping it.
    numeric = dataframe[colnames].select_dtypes(include='number')

    # Drop rows/columns of the correlation matrix that are entirely NaN
    # (e.g. columns without variance). Passing a tuple of axes to ``dropna``
    # is no longer supported by pandas, so each axis is handled separately.
    corr = numeric.corr()
    corr = corr.dropna(axis=0, how='all').dropna(axis=1, how='all')
    if corr.shape[0] < 2:
        raise ValueError(
            'plot_corrmat: need at least two numeric columns with a defined '
            'correlation in %s' % in_csv)

    # Diverging colormap so that the sign of the correlation is visible.
    cmap = sn.diverging_palette(220, 10, as_cmap=True)
    # No upper-triangle mask is applied: clustermap reorders rows and columns,
    # so a triangle computed on the unordered matrix would end up scrambled.
    corrplot = sn.clustermap(corr, cmap=cmap, center=0., method='average',
                             square=True, linewidths=.5)
    # Keep the row labels horizontal so they remain readable.
    plt.setp(corrplot.ax_heatmap.yaxis.get_ticklabels(),
             rotation='horizontal')

    if out_file is None:
        out_file = 'corr_matrix.svg'
    fname, ext = op.splitext(out_file)
    fmt = ext[1:].lower()
    if fmt not in ('pdf', 'svg', 'png'):
        # Unsupported or missing extension: fall back to SVG.
        fmt = 'svg'
        out_file = fname + '.svg'
    corrplot.savefig(out_file, format=fmt, bbox_inches='tight',
                     pad_inches=0, dpi=100)
    return corrplot
评论列表
文章目录