def plot_samples(nade, shape, name, rows=5, cols=10):
    """Draw rows*cols samples from `nade`, plot them in a grid and save it.

    Samples are sorted by decreasing log-density before plotting.
    Subplots fill the grid row by row, so the highest-density sample is
    in the top-left cell.

    Parameters
    ----------
    nade : model object
        Must provide ``setup_n_orderings(n=...)``, ``sample(n)`` and
        ``logdensity(x)``; ``setup_n_orderings(n=1)`` is called before
        every draw.
    shape : tuple of int
        Shape each sample is displayed with. The sample vector is passed
        through ``np.resize`` to ``np.prod(shape)`` elements (repeating or
        truncating as needed) and then reshaped.
    name : str
        File name of the saved figure, joined onto ``DESTINATION_PATH``.
    rows, cols : int, optional
        Grid dimensions; ``rows * cols`` samples are drawn.
    """
    n_samples = rows * cols

    # Draw every sample together with its log-density.
    images = []
    for _ in range(n_samples):
        nade.setup_n_orderings(n=1)
        # NOTE(review): assumes nade.sample(n) returns samples along axis 1
        # and logdensity expects a column vector -- confirm against the model.
        sample = nade.sample(1)[:, 0]
        dens = nade.logdensity(sample[:, np.newaxis])
        images.append((sample, dens))
    # Highest log-density first; the stable sort keeps draw order on ties.
    images.sort(key=lambda item: -item[1])

    n_pixels = np.prod(shape)
    plt.figure(figsize=(0.5 * cols, 0.5 * rows), dpi=100)
    plt.gray()
    # Subplot indices are 1-based and fill the grid row by row.
    for index, (sample, _dens) in enumerate(images):
        plt.subplot(rows, cols, index + 1)
        image = np.resize(sample, n_pixels).reshape(shape)
        plot_sample(image, shape, origin="upper")
    plt.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01,
                        hspace=0.04, wspace=0.04)
    type_1_font()
    plt.savefig(os.path.join(DESTINATION_PATH, name))
    # NOTE(review): the figure is neither shown nor closed here, so repeated
    # calls keep accumulating open pyplot figures; consider plt.close().
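# 评论列表 ("comment list") -- stray text scraped from the source web page
# 文章目录 ("table of contents") -- stray text scraped from the source web page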