visualize.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:cifar10-tensorflow 作者: namakemono 项目源码 文件源码
def save_solvers_cmp(is_power_point=False):
    """Plot per-epoch test error of the ResNet-32 optimizer runs and save it.

    Reads every ``../output/cifar10classifier_resnet32_*.csv`` whose file-name
    suffix (the text after the last ``_``) is one of the compared solvers
    (adadelta, adagrad, adam, momentum, rmsprop). It overlays their test-error
    curves with seaborn's ``pointplot``, one hue per run ``name``, and writes
    the figure to ``../figures/resnet.solvers.png``.

    Args:
        is_power_point: When True, draw a larger legend (bigger markers and
            font), e.g. for slides.

    Raises:
        ValueError: If no matching solver CSV file is found.
    """
    import os

    # ``sns.plt`` was only an alias of this module and no longer exists in
    # current seaborn, so use pyplot directly.
    import matplotlib.pyplot as plt

    solvers = {"adadelta", "adagrad", "adam", "momentum", "rmsprop"}
    dfs = []
    # Sort so the hue order (and thus colours / legend order) is reproducible.
    for filename in sorted(glob.glob("../output/cifar10classifier_resnet32_*.csv")):
        stem = os.path.splitext(os.path.basename(filename))[0]
        if stem.rsplit("_", 1)[-1] not in solvers:
            continue
        df = pd.read_csv(filename)
        # Accuracies are fractions in [0, 1]; the y axis is labelled in percent.
        df["test_error"] = 100 * (1 - df["test_accuracy"])
        dfs.append(df)
    if not dfs:
        raise ValueError(
            "no solver result CSVs found matching "
            "../output/cifar10classifier_resnet32_*.csv"
        )

    total_df = pd.concat(dfs, ignore_index=True)
    total_df["name"] = (
        total_df["name"]
        .str.split("_")
        .str.get(-1)
        .str.replace("Momentum", "Nesterov(Original Paper)", regex=False)
    )
    # pointplot puts one categorical tick per distinct epoch, in sorted order.
    epochs = sorted(total_df["epoch"].unique())

    os.makedirs("../figures", exist_ok=True)
    fig, ax = plt.subplots()
    try:
        # NOTE: ``scale`` is deprecated in seaborn >= 0.13 but still honoured.
        sns.pointplot(x="epoch", y="test_error", hue="name", data=total_df,
                      scale=0.2, ax=ax)
        if is_power_point:
            ax.legend(loc="lower left", markerscale=9.0, fontsize=20)
        else:
            ax.legend(loc="lower left", markerscale=3.0)
        ax.set(ylim=(0, 20))
        # Label every 10th tick with its real epoch value. The tick count
        # must match the label count, or newer matplotlib raises ValueError.
        ax.set_xticks(range(len(epochs)))
        ax.set_xticklabels(
            [str(epoch) if i % 10 == 0 else "" for i, epoch in enumerate(epochs)]
        )
        ax.set(xlabel='epoch', ylabel='error(%)')
        fig.savefig("../figures/resnet.solvers.png")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号