def plot_fig(values, title, x_str, y_str, path, best_iter, std_vals=None):
    """Plot a curve of values and save the figure to ``path``.

    The legend reports the value at ``best_iter`` (formatted with six
    decimals) together with the reported iteration index.

    Input:
        values: list, tuple or 1-D numpy.ndarray of values to plot (y).
        title: string; the title of the plot.
        x_str: string; the name of the x axis.
        y_str: string; the name of the y axis.
        path: string; path where to save the figure.
        best_iter: integer. The epoch of the best iteration. A negative
            value, or one that is past the end of ``values``, selects the
            last value instead, and the legend then reports ``len(values)``.
        std_vals: optional list or numpy.ndarray of standard deviation
            values, one per entry of ``values``. When given, a shaded band
            of +/- one standard deviation is drawn around the curve.
    Raises:
        IndexError: if ``values`` is empty.
        ValueError: if ``std_vals`` does not have the same shape as
            ``values``.
    """
    n_values = len(values)
    # Negative or out-of-range indices fall back to the last value; this
    # single check covers lists, tuples and 1-D arrays alike.
    if best_iter < 0 or best_iter >= n_values:
        best_iter = -1
    # Builtin float() replaces the np.float alias removed in NumPy 1.24.
    v = "%.6f" % float(values[best_iter])
    if best_iter == -1:
        # NOTE(review): for the fallback case the legend shows len(values)
        # rather than the last index (len(values) - 1); kept as-is for
        # backward compatibility.
        best_iter = n_values

    fig = plt.figure()
    try:
        line, = plt.plot(
            values,
            label="lower val: " + v + " at " + str(best_iter) + " " +
            x_str)
        if std_vals is not None:
            mean = np.asarray(values, dtype=float)
            std = np.asarray(std_vals, dtype=float)
            if std.shape != mean.shape:
                raise ValueError(
                    "std_vals must have the same shape as values")
            # Shade +/- one standard deviation in the curve's own color.
            plt.fill_between(np.arange(n_values), mean - std, mean + std,
                             color=line.get_color(), alpha=0.3)
        plt.xlabel(x_str)
        plt.ylabel(y_str)
        plt.title(title, fontsize=8)
        plt.legend(loc='upper right', fancybox=True, shadow=True,
                   prop={'size': 8})
        plt.grid(True)
        fig.savefig(path, bbox_inches='tight')
    finally:
        # Close only the figure created here (not the caller's figures),
        # even when saving fails, so figures do not accumulate in memory.
        plt.close(fig)
tools.py 文件源码
python
阅读 19
收藏 0
点赞 0
评论 0
评论列表
文章目录