plot.py source code

python

Project: tensorflow_end2end_speech_recognition Author: hirofumi0810
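import os

import numpy as np
import matplotlib.pyplot as plt


def plot_loss(train_losses, dev_losses, steps, save_path):
    """Save the history of training and dev loss as a CSV file and a PNG figure.

    Args:
        train_losses (list): train losses, one per entry in steps
        dev_losses (list): dev losses, one per entry in steps
        steps (list): steps at which the losses were recorded
        save_path (str): directory in which loss.csv and loss.png are written
    """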
    # Save as csv file
    loss_graph = np.column_stack((steps, train_losses, dev_losses))
    if os.path.isfile(os.path.join(save_path, "loss.csv")):
        os.remove(os.path.join(save_path, "loss.csv"))
    np.savetxt(os.path.join(save_path, "loss.csv"), loss_graph, delimiter=",")

    # TODO: error check for inf loss

    # Plot & save as png file
    plt.clf()
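    # Colors are Matplotlib named colors (an assumption: this excerpt does
    # not define the `blue` and `orange` variables it originally referenced)
    plt.plot(steps, train_losses, color="tab:blue", label="Train")
    plt.plot(steps, dev_losses, color="tab:orange", label="Dev")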
    plt.xlabel('step', fontsize=12)
    plt.ylabel('loss', fontsize=12)
    plt.legend(loc="upper right", fontsize=12)
    if os.path.isfile(os.path.join(save_path, "loss.png")):
        os.remove(os.path.join(save_path, "loss.png"))
    plt.savefig(os.path.join(save_path, "loss.png"), dpi=500)
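
The TODO in plot_loss mentions an error check for inf loss but gives no implementation. One possible approach, shown here only as a sketch, is to drop every step whose train or dev loss is not finite before the values are saved and plotted. The helper name filter_finite_losses is hypothetical and does not come from the project; skipping bad steps (rather than raising an error) is likewise an assumption.

```python
import math


def filter_finite_losses(steps, train_losses, dev_losses):
    """Hypothetical helper: keep only steps where both losses are finite.

    Entries whose train or dev loss is inf or NaN are dropped, so the
    three returned lists stay aligned with each other.
    """
    kept = [
        (step, train, dev)
        for step, train, dev in zip(steps, train_losses, dev_losses)
        if math.isfinite(train) and math.isfinite(dev)
    ]
    if not kept:
        # Nothing usable: return three empty lists
        return [], [], []
    steps_ok, train_ok, dev_ok = (list(column) for column in zip(*kept))
    return steps_ok, train_ok, dev_ok


# Step 200 has an infinite train loss and step 300 a NaN dev loss,
# so only step 100 survives.
steps, train, dev = filter_finite_losses(
    [100, 200, 300], [2.5, float("inf"), 1.1], [2.7, 1.9, float("nan")]
)
print(steps, train, dev)  # [100] [2.5] [2.7]
```

The filtered lists could then be passed to plot_loss in place of the raw ones. Whether silently skipping such steps is acceptable, or whether training should stop instead, depends on the training setup.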