def plot_aucs(aucs, x_col, y_col, groupby_col, colors):
    """
    Scatter plot aucs[x_col] vs aucs[y_col], colored by colors[groupby_col].

    Light dashed guide lines are drawn along the diagonal y = x and at 0.5
    on both axes. Both axes are fixed to [0, 1].

    Parameters
    ----------
    aucs : pandas DataFrame
        has x_col, y_col, and groupby_col
    x_col, y_col, groupby_col : str
        column names for the x values, the y values, and the grouping key
    colors : dict
        values in groupby_col: color to plot

    Returns
    -------
    fig : matplotlib Figure
    ax : matplotlib Axes
    lgd : matplotlib Legend
        legend anchored just outside the right edge of the axes

    Notes
    -----
    Calls sns.set_style('white'), which changes the global seaborn style.
    The group value 'cdi' is labelled 'diarrhea'. Every other group is
    labelled with the upper-cased string form of its value.
    """
    sns.set_style('white')
    fig, ax = plt.subplots(figsize=(4, 3))

    # Faint reference lines: the diagonal, plus x = 0.5 and y = 0.5.
    ax.plot([0, 1], [0, 1], '--', c='0.95')
    ax.plot([0.5, 0.5], [0, 1], '--', c='0.95')
    ax.plot([0, 1], [0.5, 0.5], '--', c='0.95')

    for g, subdf in aucs.groupby(groupby_col):
        if g == 'cdi':
            label = 'diarrhea'
        else:
            # str() so non-string group keys (ints, tuples) don't break .upper()
            label = str(g).upper()
        # Use color= rather than c=. A single RGB/RGBA tuple passed as c= is
        # treated as values to colormap when its length matches the number
        # of points (e.g. 3 points with an RGB tuple).
        ax.scatter(subdf[x_col], subdf[y_col], color=colors[g], label=label)

    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    fig.tight_layout()

    # Shrink current axis width by 20% to make room for the outside legend
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    # Put a legend to the right of the current axis
    lgd = ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    return fig, ax, lgd
figure.healthy_vs_disease_classifier.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录