def plot_lc(lc, metrics=None, outputs=False):
    """Plot learning curves as a seaborn facet grid.

    Parameters
    ----------
    lc : pandas.DataFrame
        Wide table with ``split`` and ``epoch`` columns plus one column per
        logged quantity. Columns whose name contains no underscore are
        treated as metric names (e.g. ``loss``); every other column that
        contains a metric name is assigned to that metric, and its
        ``_<metric>`` suffix and any ``cpg_`` substring are stripped to form
        the displayed output label. The bare metric columns themselves are
        labelled ``mean`` (presumably an aggregate over outputs -- confirm
        against the code that produces ``lc``).
    metrics : str or iterable of str, optional
        Keep only columns whose name matches one of these patterns
        (interpreted as regular expressions). ``None`` or an empty value
        keeps every column.
    outputs : bool, default False
        If True, plot the per-output curves; otherwise plot only the
        ``mean`` (bare metric) curves.

    Returns
    -------
    seaborn.FacetGrid
        Grid with one row per metric and one column per split, coloured by
        output label, plotting ``value`` against ``epoch``.
    """
    lc = pd.melt(lc, id_vars=['split', 'epoch'], var_name='output')

    if metrics is not None:
        # Accept a single name or any iterable of names (a tuple used to be
        # wrapped as one element and then crash in str.join).
        if isinstance(metrics, str):
            metrics = [metrics]
        metrics = list(metrics)
        if metrics:
            # Non-capturing group: a capturing group makes str.contains emit
            # a "has match groups" UserWarning without changing the result.
            pattern = '(?:%s)' % '|'.join(metrics)
            lc = lc.loc[lc['output'].str.contains(pattern)]

    # Work on an independent frame: the filtered result above may be a slice
    # of the melted frame, and assigning columns to it would trigger
    # SettingWithCopyWarning.
    lc = lc.copy()

    is_bare = ~lc['output'].str.contains('_', regex=False)
    metric_names = lc['output'][is_bare].unique()
    lc['metric'] = ''
    for metric in metric_names:
        # Names come from column labels, so match/replace them literally
        # (pandas changed the str.replace regex default in 2.0).
        has_metric = lc['output'].str.contains(metric, regex=False)
        lc.loc[has_metric, 'metric'] = metric
        lc.loc[lc['output'] == metric, 'output'] = 'mean'
        lc['output'] = lc['output'].str.replace('_%s' % metric, '',
                                                regex=False)
    lc['output'] = lc['output'].str.replace('cpg_', '', regex=False)

    if outputs:
        lc = lc.loc[lc['output'] != 'mean']
    else:
        lc = lc.loc[lc['output'] == 'mean']

    # Panel height is left at FacetGrid's default of 3: the old `size`
    # keyword was renamed to `height` in seaborn 0.9 and later removed.
    grid = sns.FacetGrid(lc, col='split', row='metric', hue='output',
                         sharey=False, aspect=1.2, legend_out=True)
    grid.map(mpl.pyplot.plot, 'epoch', 'value', linewidth=2)
    grid.set(ylabel='')
    grid.add_legend()
    return grid
评论列表
文章目录