def plot_validation_cost(train_error, val_error, class_rate=None, savefilename=None):
    """Plot per-epoch training and validation cost, optionally with a
    classification-rate curve on a secondary y-axis, and show the figure.

    The x-axis is the epoch index ``0 .. len(train_error) - 1``. ``val_error``
    and ``class_rate`` (when given) are plotted against the same indices, so
    they are expected to have the same length as ``train_error``.

    Args:
        train_error: Sequence of training-cost values, one per epoch.
        val_error: Sequence of validation-cost values, one per epoch.
        class_rate: Optional sequence of classification-rate values, one per
            epoch. When not None it is drawn in red against a second y-axis
            sharing the same x-axis.
        savefilename: Optional output path. When truthy, the figure is saved
            there before ``plt.show()`` is called.
    """
    epochs = range(len(train_error))
    fig, ax1 = plt.subplots()
    ax1.plot(epochs, train_error, label='train error')
    ax1.plot(epochs, val_error, label='validation error')
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('cost')
    ax1.set_title('Validation Cost')
    # Copy so we can safely append the twin axis' lines for a combined legend.
    lines = list(ax1.get_lines())

    # Shrink the plotting area by 10% from the bottom to leave room for the
    # legend that is placed underneath the axes.
    box = ax1.get_position()
    shrunk = [box.x0, box.y0 + box.height * 0.1, box.width, box.height * 0.9]
    ax1.set_position(shrunk)

    if class_rate is not None:
        ax2 = ax1.twinx()
        # The twin axes starts its own color cycle, so use an explicit color
        # to keep this curve distinguishable from 'train error'.
        ax2.plot(epochs, class_rate, label='classification rate', color='r')
        ax2.set_ylabel('classification rate')
        # Keep the twin aligned with the shrunken primary axes.
        ax2.set_position(shrunk)
        lines.extend(ax2.get_lines())

    labels = [line.get_label() for line in lines]
    # One combined legend for both axes, centered just below the plot area.
    ax1.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, -0.05),
               fancybox=False, shadow=False, ncol=5)
    if savefilename:
        fig.savefig(savefilename)
    plt.show()
评论列表
文章目录