def show_chains(rbm, state, dataset, num_particles=20, num_samples=20, show_every=10, display=True,
                figname='Gibbs chains', figtitle='Gibbs chains'):
    """Run chains from `state` and tile their visible expectations into an image grid.

    Starting from the first `num_particles` chains of `state`, this records
    ``rbm.vis_expectations(state.h)`` `num_samples` times. Between consecutive
    recordings every chain is advanced by `show_every` calls to ``rbm.step``.
    The first recording is taken from the initial state, before any step.

    Each recorded visible vector is cut to its first
    ``dataset.num_rows * dataset.num_cols`` entries and reshaped to a
    ``(num_rows, num_cols)`` image. The images are then tiled with one row per
    chain and one column per recording.

    Args:
        rbm: object providing ``vis_expectations(h)`` and ``step(state)``.
        state: chain state with ``.v`` and ``.h``; it is sliced as
            ``state[:num_particles, :, :]``. NOTE(review): presumably axis 0
            indexes chains, and at least `num_particles` chains must be present
            or the per-recording assignment cannot broadcast. Confirm with callers.
        dataset: provides ``num_rows`` and ``num_cols`` of each image.
        num_particles: number of chains shown (grid rows).
        num_samples: number of recordings per chain (grid columns).
        show_every: number of ``rbm.step`` calls between recordings.
        display: if true, draw the grid into the pylab figure `figname`.
        figname: identifier passed to ``pylab.figure``.
        figtitle: title drawn above the grid.

    Returns:
        The tiled grid produced by ``vm.vjoin``.
    """
    num_vis = state.v.shape[1]
    recorded = gnp.zeros((num_particles, num_samples, num_vis))
    chains = state[:num_particles, :, :]

    for t in range(num_samples):
        recorded[:, t, :] = rbm.vis_expectations(chains.h)
        # The chains are advanced after every recording, including the last one.
        # Those final steps are never shown, but they still run (this may matter
        # if `rbm.step` draws random numbers).
        for _ in range(show_every):
            chains = rbm.step(chains)

    height, width = dataset.num_rows, dataset.num_cols
    npix = height * width

    def _to_image(particle, t):
        # Keep only the first `npix` visible units and convert via as_numpy_array().
        return recorded[particle, t, :npix].reshape((height, width)).as_numpy_array()

    rows = []
    for particle in range(num_particles):
        images = [_to_image(particle, t) for t in range(num_samples)]
        rows.append(vm.hjoin(images, normalize=False))
    grid = vm.vjoin(rows, normalize=False)

    if not display:
        return grid

    pylab.figure(figname)
    pylab.matshow(grid, cmap='gray', fignum=False)
    pylab.title(figtitle)
    pylab.gcf().canvas.draw()
    return grid
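# Web-page residue (not code): 评论列表 ("comment list")
# Web-page residue (not code): 文章目录 ("table of contents")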