def generate_plot(self, filename, title='', xlabel='', ylabel='', xlim=None, ylim=None):
    """Draw every data series on one shared set of axes and save it to a file.

    Intended for one or more series sharing a common x axis, e.g. the
    training error and the validation error plotted against epochs.

    Args:
        filename: Destination passed to ``savefig`` on the figure.
        title: Title of the axes.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        xlim: Optional x-axis limits (presumably a ``(low, high)`` pair);
            ``None`` leaves the limits to the plotting library.
        ylim: Optional y-axis limits, same convention as ``xlim``.

    Side effects:
        Calls ``self.sort_and_validate()``, clears or creates ``self.plot``
        and writes the rendered figure to ``filename``.
    """
    logger = logging.getLogger("plotting")
    logger.debug('MultipleSeriesPlot.generate_plot')

    # sort the data series and make sure they are consistent
    self.sort_and_validate()

    # If a plot already exists, clear it and re-use it; this avoids creating
    # extraneous figures which would stay in memory even once we no longer
    # reference them.
    if self.plot:
        self.plot.clf()
    else:
        self.plot = plt.figure()

    splt = self.plot.add_subplot(1, 1, 1)
    splt.set_title(title)
    splt.set_xlabel(xlabel)
    splt.set_ylabel(ylabel)

    # Set the limits on *our* axes. The pyplot/pylab state-machine helpers
    # act on whichever figure is "current", which is not necessarily
    # self.plot when several figures are alive and this one is being re-used.
    # `is not None` (rather than truthiness) also accepts numpy arrays.
    if xlim is not None:
        splt.set_xlim(xlim)
    if ylim is not None:
        splt.set_ylim(ylim)

    for series_name, data_points in self.data.items():
        # each data point is indexable: element 0 is x, element 1 is y
        xpoints = numpy.asarray([seq[0] for seq in data_points])
        ypoints = numpy.asarray([seq[1] for seq in data_points])
        line, = splt.plot(xpoints, ypoints, '-', linewidth=2)
        logger.debug('set_label for %s', series_name)
        line.set_label(series_name)

    splt.legend()

    # TODO: better filename configuration for plots
    self.plot.savefig(filename)
评论列表
文章目录