def plot_results(self, results, xloc, color, ls, label):
    """Plot the training log-probability curve for the results of ``self.average``.

    Parameters
    ----------
    results : dict
        Keyed by ``(iteration_count, average)`` tuples. Only entries whose
        second key component equals ``self.average`` are plotted, ordered by
        iteration count. Each value must provide ``train_logprob()``; if the
        values also provide ``train_logprob_interval()`` (indexable, with the
        lower bound at ``[0]`` and the upper bound at ``[1]``), error bars
        are drawn as well.
    xloc : sequence
        x-coordinates of the points; truncated to the number of plotted points.
    color, ls :
        Color and line style passed to the pylab plotting calls.
    label :
        Legend label. It is attached only the first time a given label is
        plotted (tracked in ``self.labels``) so the legend has no duplicates.
    """
    iter_counts = sorted({it for it, av in results if av == self.average})
    sorted_results = [results[it, self.average] for it in iter_counts]
    avg = np.array([r.train_logprob() for r in sorted_results])

    # Inspect a concrete result object: a comprehension's loop variable is not
    # visible after the comprehension (Python 3), and there is none at all
    # when there are no results.
    lower = upper = None
    if sorted_results and hasattr(sorted_results[-1], 'train_logprob_interval'):
        # Evaluate each interval once rather than once per bound.
        intervals = [r.train_logprob_interval() for r in sorted_results]
        lower = np.array([iv[0] for iv in intervals])
        upper = np.array([iv[1] for iv in intervals])

    plot_cmd = pylab.semilogx if self.logscale else pylab.plot

    xloc = xloc[:len(avg)]

    lw = 2.
    # Only the first curve with a given label contributes a legend entry.
    if label not in self.labels:
        plot_cmd(xloc, avg, color=color, ls=ls, lw=lw, label=label)
    else:
        plot_cmd(xloc, avg, color=color, ls=ls, lw=lw)
    self.labels.add(label)

    pylab.xticks(fontsize='xx-large')
    pylab.yticks(fontsize='xx-large')

    # Draw error bars only when intervals are available, instead of relying
    # on a bare ``except`` to hide the unbound-name error (and everything else).
    if lower is not None:
        pylab.errorbar(xloc, (lower + upper) / 2., yerr=(upper - lower) / 2.,
                       fmt='', ls='None', ecolor=color)
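# NOTE(review): stray web-page navigation text ("comment list",
# "table of contents") was pasted here; it is not part of the program.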