visualize.py source code

python

Project: cifar10-tensorflow  Author: namakemono
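import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def save_layers_cmp(is_power_point=False):
    """Plot per-epoch test error for ResNet-20/32/44/56/110 on CIFAR-10 and save the figure."""
    frames = []
    for layer in [20, 32, 44, 56, 110]:
        df = pd.read_csv("../output/cifar10classifier_resnet%d.csv" % layer)
        df["train_error"] = 1 - df["train_accuracy"]
        df["test_error"] = 1 - df["test_accuracy"]
        frames.append(df[df["epoch"] < 150])  # keep only the first 150 epochs
    # Concatenate once instead of growing the DataFrame inside the loop.
    total_df = pd.concat(frames, ignore_index=True)
    # Use the last "_"-separated token of each run name as its legend label.
    total_df["name"] = total_df["name"].str.split("_").str.get(-1)
    ax = sns.pointplot(x="epoch", y="test_error", hue="name", data=total_df, scale=0.2)
    if is_power_point:
        # Larger legend markers and text for presentation slides.
        ax.legend(loc="lower left", markerscale=9.0, fontsize=20)
    else:
        ax.legend(markerscale=3.0)
    ax.set(ylim=(0, 0.2))
    # Label only every 10th epoch so the x axis stays readable.
    ax.set_xticklabels([i if i % 10 == 0 else "" for i in range(150)])
    # Errors are fractions in [0, 1], not percentages.
    ax.set(xlabel="epoch", ylabel="error rate")
    ax.get_figure().savefig("../figures/resnet.layers.png")
    # sns.plt is not available in recent seaborn releases; close via matplotlib.
    plt.close(ax.get_figure())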
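The data preparation inside save_layers_cmp (error = 1 - accuracy, an epoch cutoff, stacking the runs, and shortening run names for the legend) can be checked without the CSV files or a plotting backend. The sketch below pulls that step into a hypothetical helper, combine_runs, and feeds it two small synthetic logs. It assumes each log has the columns "epoch", "name", "train_accuracy" and "test_accuracy", and that run names look like "cifar10classifier_resnet20"; neither the helper nor those name strings come from the original project.

```python
import pandas as pd


def combine_runs(frames, max_epoch=150):
    """Stack per-depth training logs into one long DataFrame for plotting.

    Hypothetical helper: each frame is assumed to have the columns
    "epoch", "name", "train_accuracy" and "test_accuracy".
    """
    parts = []
    for df in frames:
        df = df.copy()  # avoid mutating the caller's DataFrame
        df["train_error"] = 1 - df["train_accuracy"]
        df["test_error"] = 1 - df["test_accuracy"]
        parts.append(df[df["epoch"] < max_epoch])  # drop epochs >= max_epoch
    total = pd.concat(parts, ignore_index=True)
    # Keep only the last "_"-separated token of the run name as the legend label.
    total["name"] = total["name"].str.split("_").str.get(-1)
    return total


# Two tiny synthetic logs standing in for the real CSV files (assumed format).
run20 = pd.DataFrame({
    "epoch": [0, 1, 150],
    "name": ["cifar10classifier_resnet20"] * 3,
    "train_accuracy": [0.5, 0.75, 0.99],
    "test_accuracy": [0.5, 0.75, 0.93],
})
run32 = pd.DataFrame({
    "epoch": [0, 1, 150],
    "name": ["cifar10classifier_resnet32"] * 3,
    "train_accuracy": [0.25, 0.5, 0.99],
    "test_accuracy": [0.25, 0.5, 0.94],
})

combined = combine_runs([run20, run32])
print(combined[["epoch", "name", "test_error"]])
```

The epoch-150 rows are dropped and the legend labels become "resnet20" and "resnet32"; the resulting long-format frame is what sns.pointplot consumes with hue="name".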