def after_step(self, rbm, trainer, i):
    """Checkpoint and/or visualize training state after a training step.

    ``i + 1`` is used as the update count for this step (the particle figure
    is titled "PCD particles (<count> updates)"). That count is looked up in
    ``self.expt.save_after`` and ``self.expt.show_after`` to decide whether
    to write checkpoints and whether to display figures on this step.

    Args:
        rbm: The model being trained; dumped via ``storage.dump`` on save steps.
        trainer: Object exposing ``fantasy_particles`` (read only on the
            branches that need it) and, optionally, ``avg_rbm``.
        i: Index of the step that just finished (presumably zero-based,
            given the ``+ 1`` below).

    ``self.subset`` selects which optional outputs run: 'particles',
    'gibbs_chains' and/or 'objective'.
    """
    n_updates = i + 1
    expt = self.expt

    save = n_updates in expt.save_after
    display = n_updates in expt.show_after
    want_figures = save or display

    # Write checkpoint files for this step.
    if save:
        if expt.save_particles:
            storage.dump(trainer.fantasy_particles, expt.pcd_particles_file(n_updates))
        storage.dump(rbm, expt.rbm_file(n_updates))
        # Only trainers that carry an averaged model get it saved.
        if hasattr(trainer, 'avg_rbm'):
            storage.dump(trainer.avg_rbm, expt.avg_rbm_file(n_updates))
        # Wall-clock seconds elapsed since self.t0.
        storage.dump(time.time() - self.t0, expt.time_file(n_updates))

    # Figure of the persistent fantasy particles.
    if 'particles' in self.subset and want_figures:
        particles_title = 'PCD particles ({} updates)'.format(n_updates)
        fig = rbm_vis.show_particles(rbm, trainer.fantasy_particles, expt.dataset,
                                     display=display, figtitle=particles_title)
        if display:
            # Redraw the current pylab figure; presumably the one just created
            # by show_particles -- confirm.
            pylab.gcf().canvas.draw()
        if save:
            misc.save_image(fig, expt.pcd_particles_figure_file(n_updates))

    # Figure of the Gibbs chains started from the fantasy particles.
    if 'gibbs_chains' in self.subset and want_figures:
        chains_title = 'Gibbs chains (iteration {})'.format(n_updates)
        fig = diagnostics.show_chains(rbm, trainer.fantasy_particles, expt.dataset,
                                      display=display, figtitle=chains_title)
        if save:
            misc.save_image(fig, expt.gibbs_chains_figure_file(n_updates))

    # The objective tracker is updated on every step, not only on save/show steps.
    if 'objective' in self.subset:
        self.log_prob_tracker.update(rbm, trainer.fantasy_particles)

    # Final redraw of whichever pylab figure is current after the updates above.
    if display:
        pylab.gcf().canvas.draw()
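# [extraction residue from a web page] "Comment list"
# [extraction residue from a web page] "Table of contents"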