def plot_examples(nade, dataset, shape, name, rows=5, cols=10):
    """Plot a grid of dataset examples ranked by their model log-density.

    Draws ``rows * cols`` single examples from ``dataset`` and scores each
    one with ``nade.logdensity``. It sorts them from highest to lowest
    density and saves them as a ``rows`` x ``cols`` grid of grayscale images
    to ``os.path.join(DESTINATION_PATH, name)``.

    Args:
        nade: Model exposing ``setup_n_orderings(n=...)`` and
            ``logdensity(x)``.
        dataset: Object exposing ``sample_data(k)``. The first element of
            its result is used as the example.
        shape: Shape each example is reshaped to before plotting.
        name: File name of the saved figure, relative to
            ``DESTINATION_PATH``.
        rows: Number of rows in the grid.
        cols: Number of columns in the grid.
    """
    n_images = rows * cols
    n_pixels = np.prod(shape)

    # Show some samples: draw and score one example at a time.
    # range() instead of xrange() so this runs on both Python 2 and 3.
    images = []
    for _ in range(n_images):
        # NOTE(review): presumably re-draws a single random ordering before
        # each evaluation so densities are not tied to one ordering — confirm.
        nade.setup_n_orderings(n=1)
        # NOTE(review): .T presumably puts the example in the layout that
        # logdensity expects — TODO confirm.
        sample = dataset.sample_data(1)[0].T
        dens = nade.logdensity(sample)
        images.append((sample, dens))

    # Highest density first. list.sort is stable even with reverse=True, so
    # ties keep their draw order, exactly as the previous key=-density did.
    images.sort(key=lambda item: item[1], reverse=True)

    fig = plt.figure(figsize=(0.5 * cols, 0.5 * rows), dpi=100)
    plt.gray()
    for i, (sample, _) in enumerate(images):
        plt.subplot(rows, cols, i + 1)
        # np.resize repeats or truncates the data to exactly prod(shape)
        # values, so the reshape below cannot fail on a size mismatch.
        plot_sample(np.resize(sample, n_pixels).reshape(shape), shape,
                    origin="upper")
    plt.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01,
                        hspace=0.04, wspace=0.04)
    type_1_font()
    plt.savefig(os.path.join(DESTINATION_PATH, name))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
评论列表
文章目录