logging_plotting.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:world_merlin 作者: pbaljeka 项目源码 文件源码
def generate_plot(self, filename, title='', xlabel='', ylabel='', xlim=None, ylim=None):
        """Draw every series in ``self.data`` as a line plot and save it to a file.

        The series share a common x axis (e.g. the training error and the
        validation error plotted against epochs). If ``self.plot`` already
        holds a figure, it is cleared and re-used. Otherwise a new figure is
        created with ``plt.figure()`` and stored in ``self.plot``.

        :param filename: destination passed to ``Figure.savefig``.
        :param title: axes title.
        :param xlabel: x-axis label.
        :param ylabel: y-axis label.
        :param xlim: optional x-axis limits, e.g. ``(left, right)``; ``None``
            leaves the x axis autoscaled.
        :param ylim: optional y-axis limits, e.g. ``(bottom, top)``; ``None``
            leaves the y axis autoscaled.
        """
        logger = logging.getLogger("plotting")
        logger.debug('MultipleSeriesPlot.generate_plot')

        # Per the original author: sort the data series and make sure they are
        # consistent (implementation of sort_and_validate not visible here).
        self.sort_and_validate()

        # Re-use an existing figure instead of creating a new one; otherwise
        # extraneous figures stay in memory even after we stop referencing them.
        if self.plot:
            self.plot.clf()
        else:
            self.plot = plt.figure()

        splt = self.plot.add_subplot(1, 1, 1)
        splt.set_title(title)
        splt.set_xlabel(xlabel)
        splt.set_ylabel(ylabel)

        # Apply limits to *this* axes object. pylab.xlim/ylim operate on
        # pyplot's "current" axes, which belongs to a different figure whenever
        # another figure was created after self.plot. set_xlim/set_ylim turn
        # autoscaling off for that axis, so plotting below keeps these limits.
        if xlim is not None:
            splt.set_xlim(xlim)
        if ylim is not None:
            splt.set_ylim(ylim)

        # .items() works on Python 2 and 3; dict.iteritems() is Python 2 only.
        for series_name, data_points in self.data.items():
            # Each data point is indexed as [0] -> x and [1] -> y.
            xpoints = numpy.asarray([seq[0] for seq in data_points])
            ypoints = numpy.asarray([seq[1] for seq in data_points])
            logger.debug('set_label for %s', series_name)
            splt.plot(xpoints, ypoints, '-', linewidth=2, label=series_name)

        splt.legend()

        # TODO: better filename configuration for plots
        self.plot.savefig(filename)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号