def plot_loss(train_losses, dev_losses, steps, save_path):
    """Save history of training & dev loss as a csv table and a figure.

    Writes ``loss.csv`` (columns: step, train loss, dev loss) and
    ``loss.png`` into ``save_path``, replacing any earlier copies.

    Args:
        train_losses (list): train losses
        dev_losses (list): dev losses
        steps (list): steps
        save_path (str): directory the csv and png files are written to
    """
    csv_path = os.path.join(save_path, "loss.csv")
    png_path = os.path.join(save_path, "loss.png")

    # Save as csv file
    loss_graph = np.column_stack((steps, train_losses, dev_losses))
    # Remove the previous loss.csv -- the file this function actually writes
    # (not "ler.csv", which this function never produces).
    if os.path.isfile(csv_path):
        os.remove(csv_path)
    np.savetxt(csv_path, loss_graph, delimiter=",")

    # TODO: error check for inf loss

    # Plot & save as png file
    plt.clf()
    # `blue` and `orange` are not defined here; they are looked up in the
    # enclosing module scope and passed as the line format/colour argument.
    plt.plot(steps, train_losses, blue, label="Train")
    plt.plot(steps, dev_losses, orange, label="Dev")
    plt.xlabel('step', fontsize=12)
    plt.ylabel('loss', fontsize=12)
    plt.legend(loc="upper right", fontsize=12)
    if os.path.isfile(png_path):
        os.remove(png_path)
    # The resolution keyword is `dpi`.
    plt.savefig(png_path, dpi=500)
plot.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录