dcpg_train_viz.py source code

python

Project: deepcpg  Author: cangermueller
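import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def plot_lc(lc, metrics=None, outputs=False):
    """Plot learning curves from a wide table with 'split' and 'epoch' columns.

    Every other column is either a metric averaged over outputs (no '_' in
    its name, e.g. 'loss') or a per-output metric named '<output>_<metric>'.
    """
    # Reshape to long format: one row per (split, epoch, column) value.
    lc = pd.melt(lc, id_vars=['split', 'epoch'], var_name='output')
    if metrics:
        if not isinstance(metrics, list):
            metrics = [metrics]
        # Non-capturing group avoids pandas' "has match groups" warning.
        tmp = '(?:%s)' % '|'.join(metrics)
        # .copy() so the column assignments below do not act on a view.
        lc = lc.loc[lc.output.str.contains(tmp)].copy()
    # Column names without '_' are the mean metrics.
    metrics = lc.output[~lc.output.str.contains('_')].unique()
    lc['metric'] = ''

    for metric in metrics:
        lc.loc[lc.output.str.contains(metric), 'metric'] = metric
        lc.loc[lc.output == metric, 'output'] = 'mean'
        # Strip the metric suffix and the 'cpg_' prefix to get the output name.
        lc['output'] = lc.output.str.replace('_%s' % metric, '', regex=False)
        lc['output'] = lc.output.str.replace('cpg_', '', regex=False)

    # Keep either the per-output curves or only the mean curves.
    if outputs:
        lc = lc.loc[lc.output != 'mean']
    else:
        lc = lc.loc[lc.output == 'mean']

    # One panel per (metric, split); seaborn >= 0.9 uses `height`, not `size`.
    grid = sns.FacetGrid(lc, col='split', row='metric', hue='output',
                         sharey=False, height=3, aspect=1.2, legend_out=True)
    grid.map(plt.plot, 'epoch', 'value', linewidth=2)
    grid.set(ylabel='')
    grid.add_legend()
    return grid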
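As a rough illustration of the column-naming convention `plot_lc` relies on, the pandas-free sketch below maps each learning-curve column to a `(metric, output)` label. It is a simplified reimplementation, not code from deepcpg: it assumes mean metrics contain no `_` and per-output columns end in `_<metric>`. The helper `label_columns` and column names such as `cpg_cell1_loss` are hypothetical.

```python
# Hypothetical helper: a simplified, pandas-free sketch of how plot_lc
# assigns metric and output labels to learning-curve columns.
# Assumes mean metrics contain no '_' and per-output columns end in '_<metric>'.


def label_columns(columns, id_cols=('split', 'epoch')):
    """Return {column: (metric, output)} for every non-id column."""
    names = [c for c in columns if c not in id_cols]
    # Columns without an underscore are the mean metrics, e.g. 'loss'.
    metrics = [n for n in names if '_' not in n]
    labels = {}
    for name in names:
        if name in metrics:
            labels[name] = (name, 'mean')
            continue
        labels[name] = ('', name)  # fallback: no known metric suffix
        for metric in metrics:
            suffix = '_' + metric
            if name.endswith(suffix):
                # Drop the metric suffix and the 'cpg_' prefix.
                output = name[:-len(suffix)].replace('cpg_', '')
                labels[name] = (metric, output)
                break
    return labels


labels = label_columns(['split', 'epoch', 'loss', 'acc',
                        'cpg_cell1_loss', 'cpg_cell1_acc'])
print(labels)
# {'loss': ('loss', 'mean'), 'acc': ('acc', 'mean'), 'cpg_cell1_loss': ('loss', 'cell1'), 'cpg_cell1_acc': ('acc', 'cell1')}
```

With `outputs=False`, `plot_lc` keeps only the rows labelled `mean`; with `outputs=True`, it keeps the per-output rows (here `cell1`) instead.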