def plotFields(layer,fieldShape=None,channel=None,figOffset=1,cmap=None,padding=0.01):
    """Plot a layer's receptive fields and their total absolute input dependency.

    Two figures are drawn:

    * figure ``figOffset``: one image per receptive field, laid out with
      ``ImageGrid`` and sharing a single colorbar;
    * figure ``figOffset + 1``: the element-wise sum of ``abs(field)`` over
      all fields.

    :param layer: object with a ``W`` attribute whose ``eval()`` returns the
        weights, or the weight variable itself. ``layer.name`` is used in the
        figure titles.
    :param fieldShape: shape (list or tuple) of a single field. Required when
        the transposed weights have fewer than 4 dimensions (fully connected
        layer); each field is reshaped to this shape.
    :param channel: for 4-D (convolutional) weights, plot only this input
        channel; by default every (feature, channel) pair becomes a field.
    :param figOffset: figure number of the field grid.
    :param cmap: colormap passed to ``imshow``.
    :param padding: padding between grid cells (``ImageGrid`` ``axes_pad``).
    :raises ValueError: if ``fieldShape`` is missing for non-convolutional
        weights, or there are no fields to plot.
    """
    from mpl_toolkits.axes_grid1 import ImageGrid
    # Accept either a layer wrapper exposing ``.W`` or the raw weight variable.
    try:
        W = layer.W
    except AttributeError:
        W = layer
    wp = W.eval().transpose()
    if np.ndim(wp) < 4:  # fully connected layer: weights carry no field shape
        if fieldShape is None:
            raise ValueError('fieldShape is required for non-convolutional weights')
        fields = np.reshape(wp, list(wp.shape[:-1]) + list(fieldShape))
    else:  # convolutional layer: weights already have a field shape
        features, channels, iy, ix = np.shape(wp)
        if channel is not None:
            fields = wp[:, channel, :, :]
        else:
            fields = np.reshape(wp, [features * channels, iy, ix])
    nFields = fields.shape[0]
    if nFields == 0:
        raise ValueError('no receptive fields to plot')
    perRow = int(math.floor(math.sqrt(nFields)))
    perColumn = int(math.ceil(nFields / float(perRow)))
    fig = mpl.figure(figOffset); mpl.clf()
    grid = ImageGrid(fig, 111, nrows_ncols=(perRow, perColumn), axes_pad=padding, cbar_mode='single')
    # One colour scale for every field, so the single shared colorbar is valid
    # for all of them rather than only for the last image drawn.
    vmin, vmax = np.min(fields), np.max(fields)
    for i in range(nFields):
        im = grid[i].imshow(fields[i], cmap=cmap, vmin=vmin, vmax=vmax)
    grid.cbar_axes[0].colorbar(im)
    # pyplot.title() would label whichever axes pyplot considers current, which
    # is not a well-defined field here; title the whole figure instead.
    fig.suptitle('%s Receptive Fields' % layer.name)
    mpl.figure(figOffset + 1); mpl.clf()
    mpl.imshow(np.sum(np.abs(fields), 0), cmap=cmap)
    mpl.title('%s Total Absolute Input Dependency' % layer.name)
    mpl.colorbar()
# Python ImageGrid() class: example source code
def plotFields(layer,fieldShape=None,channel=None,maxFields=25,figName='ReceptiveFields',cmap=None,padding=0.01):
    """Plot up to ``maxFields`` receptive fields and the total absolute input dependency.

    Two figures are drawn:

    * figure ``figName``: the first ``maxFields`` fields, laid out with
      ``ImageGrid`` and sharing a single colorbar;
    * figure ``figName + ' Total'``: the element-wise sum of ``abs(field)``
      over *all* fields (not only the ones shown in the grid).

    :param layer: object with a ``W`` attribute whose ``eval()`` returns the
        weights; ``layer.name`` is used in the figure titles.
    :param fieldShape: shape (list or tuple) of a single field. Required when
        the transposed weights have fewer than 4 dimensions (fully connected
        layer); each field is reshaped to this shape.
    :param channel: for 4-D (convolutional) weights, plot only this input
        channel; by default every (feature, channel) pair becomes a field.
    :param maxFields: maximum number of fields drawn in the grid.
    :param figName: label of the grid figure; the summary figure appends ' Total'.
    :param cmap: colormap passed to ``imshow``.
    :param padding: padding between grid cells (``ImageGrid`` ``axes_pad``).
    :raises ValueError: if ``fieldShape`` is missing for non-convolutional
        weights, or no field would be drawn (no fields or ``maxFields <= 0``).
    """
    from mpl_toolkits.axes_grid1 import ImageGrid
    W = layer.W
    wp = W.eval().transpose()
    if np.ndim(wp) < 4:  # fully connected layer: weights carry no field shape
        if fieldShape is None:
            raise ValueError('fieldShape is required for non-convolutional weights')
        fields = np.reshape(wp, list(wp.shape[:-1]) + list(fieldShape))
    else:  # convolutional layer: weights already have a field shape
        features, channels, iy, ix = np.shape(wp)
        if channel is not None:
            fields = wp[:, channel, :, :]
        else:
            fields = np.reshape(wp, [features * channels, iy, ix])
    fieldsN = min(fields.shape[0], maxFields)
    if fieldsN <= 0:
        raise ValueError('no receptive fields to plot (check maxFields)')
    perRow = int(math.floor(math.sqrt(fieldsN)))
    perColumn = int(math.ceil(fieldsN / float(perRow)))
    fig = mpl.figure(figName); mpl.clf()
    grid = ImageGrid(fig, 111, nrows_ncols=(perRow, perColumn), axes_pad=padding, cbar_mode='single')
    shown = fields[:fieldsN]
    # One colour scale for every displayed field, so the single shared
    # colorbar is valid for all of them rather than only the last one drawn.
    vmin, vmax = np.min(shown), np.max(shown)
    for i in range(fieldsN):
        im = grid[i].imshow(shown[i], cmap=cmap, vmin=vmin, vmax=vmax)
    grid.cbar_axes[0].colorbar(im)
    # pyplot.title() would label whichever axes pyplot considers current, which
    # is not a well-defined field here; title the whole figure instead.
    fig.suptitle('%s Receptive Fields' % layer.name)
    mpl.figure(figName + ' Total'); mpl.clf()
    mpl.imshow(np.sum(np.abs(fields), 0), cmap=cmap)
    mpl.title('%s Total Absolute Input Dependency' % layer.name)
    mpl.colorbar()
def plot_image_grid(images, num_rows, num_cols, save_path=None):
    """Plots images in a grid.

    Parameters
    ----------
    images : numpy.ndarray
        Images to display, with shape
        ``(num_rows * num_cols, num_channels, height, width)``.
        Images beyond the ``num_rows * num_cols`` grid cells are ignored.
    num_rows : int
        Number of rows for the image grid.
    num_cols : int
        Number of columns for the image grid.
    save_path : str, optional
        Where to save the image grid. Defaults to ``None``,
        which causes the grid to be displayed on screen.
        When saving, the figure is closed afterwards.
    """
    figure = pyplot.figure()
    grid = ImageGrid(figure, 111, (num_rows, num_cols), axes_pad=0.1)
    for image, axis in zip(images, grid):
        # Channels-first -> channels-last, as imshow expects.
        axis.imshow(image.transpose(1, 2, 0), interpolation='nearest')
        # Hides ticks, tick labels and frame, so no per-axis label blanking is needed.
        axis.axis('off')
    if save_path is None:
        pyplot.show()
    else:
        figure.savefig(save_path, transparent=True, bbox_inches='tight')
        # Release the figure; otherwise repeated saves accumulate open figures.
        pyplot.close(figure)
def plot_image_grid(images, num_rows, num_cols, save_path=None):
    """Draw ``images`` in a ``num_rows`` x ``num_cols`` grid, then show or save it.

    Parameters
    ----------
    images : numpy.ndarray
        Channels-first images, shaped
        ``(num_rows * num_cols, num_channels, height, width)``.
    num_rows : int
        Row count of the grid.
    num_cols : int
        Column count of the grid.
    save_path : str, optional
        Output file. With the default ``None`` the grid is shown on screen
        instead; otherwise it is written at 212 dpi with a transparent
        background. The current figure is closed in both cases.
    """
    fig = pyplot.figure()
    cells = ImageGrid(fig, 111, (num_rows, num_cols), axes_pad=0.1)
    for img, ax in zip(images, cells):
        ax.imshow(img.transpose(1, 2, 0), interpolation='nearest')
        ax.set_yticklabels([''] * img.shape[1])
        ax.set_xticklabels([''] * img.shape[2])
        ax.axis('off')
    if save_path is not None:
        pyplot.savefig(save_path, transparent=True, bbox_inches='tight', dpi=212)
    else:
        pyplot.show()
    pyplot.close()
def save_imshow_grid(images, logs_dir, filename, shape):
    """
    Plot images in a grid of a given shape and save it to ``logs_dir/filename``.

    :param images: indexable sequence of images accepted by ``imshow``; the
        first ``shape[0] * shape[1]`` are drawn, and cells without an image
        are left blank.
    :param logs_dir: output directory.
    :param filename: file name of the saved figure inside ``logs_dir``.
    :param shape: ``(rows, cols)`` of the grid.
    """
    fig = plt.figure(1)
    # Figure 1 is reused across calls; clear it so grids from earlier calls are
    # not stacked underneath this one.
    fig.clf()
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)
    size = shape[0] * shape[1]
    for i in trange(size, desc="Saving images"):
        grid[i].axis('off')
        if i < len(images):
            grid[i].imshow(images[i])
    fig.savefig(os.path.join(logs_dir, filename))
def _get_grids(fig, rows, cols, axes_pad=0):
    """Create ``rows`` one-row ImageGrids stacked vertically in ``fig``.

    Each grid holds ``cols`` axes (``share_all=True``) and a single colorbar
    on the right.

    :param fig: matplotlib figure to draw into.
    :param rows: number of stacked grids.
    :param cols: number of axes per grid.
    :param axes_pad: vertical padding between axes; horizontal padding is
        fixed at 0.05.
    :returns: list of the created grids, in subplot order (index 1 first).
    """
    grids = []
    for row in range(rows):
        # Pass the subplot position as a (nrows, ncols, index) tuple. The
        # previous concatenated three-digit code (e.g. 312) became ambiguous
        # and invalid once rows >= 10.
        grid = ImageGrid(fig, (rows, 1, row + 1),
                         nrows_ncols=(1, cols),
                         axes_pad=(0.05, axes_pad),
                         share_all=True,
                         cbar_location="right",
                         cbar_mode="single",
                         cbar_size="10%",
                         cbar_pad="5%")
        grids.append(grid)
    return grids
def save_imshow_grid(images, logs_dir, filename, shape):
    """
    Plot images in a grid of a given shape and save it to ``logs_dir/filename``.

    Also writes the raw ``images`` object to ``logs_dir/image.pk`` with
    pickle. Each drawn image is saved as a JPEG named by its index
    (``logs_dir/0``, ``logs_dir/1``, ...).

    :param images: indexable sequence of arrays accepted by both ``imshow``
        and ``Image.fromarray``; the first ``shape[0] * shape[1]`` are used,
        and cells without an image are left blank.
    :param logs_dir: output directory.
    :param filename: file name of the saved grid figure inside ``logs_dir``.
    :param shape: ``(rows, cols)`` of the grid.
    """
    # Close the pickle file deterministically instead of leaking the handle.
    with open(os.path.join(logs_dir, "image.pk"), "wb") as f:
        pickle.dump(images, f)
    fig = plt.figure(1)
    # Figure 1 is reused across calls; clear it so grids from earlier calls are
    # not stacked underneath this one.
    fig.clf()
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)
    size = shape[0] * shape[1]
    for i in trange(size, desc="Saving images"):
        grid[i].axis('off')
        if i >= len(images):
            continue
        grid[i].imshow(images[i])
        Image.fromarray(images[i]).save(os.path.join(logs_dir, str(i)), "jpeg")
    fig.savefig(os.path.join(logs_dir, filename))
def main(argv=None):
    """Sample the generator once and save/show the samples as a GRID montage.

    The batch size is set to ``GRID[0] * GRID[1]``, so one image per grid
    cell is generated from Gaussian noise. The montage is written to
    ``montage.png`` and then shown on screen.

    :param argv: unused; kept so the entry-point signature stays unchanged.
    """
    input.init_dataset_constants()
    num_images = GRID[0] * GRID[1]
    FLAGS.batch_size = num_images
    with tf.Graph().as_default():
        g_template = model.generator_template()
        z = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.z_size])
        # np.random.seed(1337)  # uncomment to generate the same samples each run
        noise = np.random.normal(size=(FLAGS.batch_size, FLAGS.z_size))
        with pt.defaults_scope(phase=pt.Phase.test):
            gen_images_op, _ = pt.construct_all(g_template, input=z)
        # The context manager closes the session even if init or sampling fails.
        with tf.Session() as sess:
            init_variables(sess)
            gen_images, = sess.run([gen_images_op], feed_dict={z: noise})
    # Shift/scale (x + 1) / 2 -- presumably maps generator output from [-1, 1]
    # to [0, 1] for imshow; confirm against the generator's output activation.
    gen_images = (gen_images + 1) / 2
    fig = plt.figure(1)
    grid = ImageGrid(fig, 111,
                     nrows_ncols=GRID,
                     axes_pad=0.1)
    for i in range(num_images):
        axis = grid[i]
        axis.axis('off')
        axis.imshow(gen_images[i])
    # Save before showing, so the file does not depend on the figure surviving
    # the interactive window being closed.
    fig.savefig('montage.png', dpi=100, bbox_inches='tight')
    plt.show()
def generate_images_line_save(self, line_segment, query_id, image_original_space=None):
    """
    Save the decoded query point and the decoded query line to disk.

    Two files are written to ``self.save_path_queries``:

    * ``pointquery_<n>_<query_id>.png`` -- the decoded query point;
    * ``linequery_<n>_<query_id>.pdf`` -- the decoded points of
      ``line_segment`` side by side in a single row,

    where ``<n>`` is ``self.n_queries + 1``. The ID of the query point from
    which the query line was generated is therefore part of both filenames.

    :param line_segment: points of the query line; they are mapped back with
        ``self.dataset.scaling_transformation.inverse_transform`` before
        decoding.
    :param query_id: index of the query point in
        ``self.dataset.data["features"]``.
    :param image_original_space: optional point, already in the original
        space, decoded instead of the dataset point at ``query_id``.
    :return: None. Any exception is printed with its traceback and re-raised.
    """
    figure = None
    try:
        if image_original_space is not None:
            x = self.generative_model.decode(image_original_space.T)
        else:
            # dataset.data["features"] is already in the original space in
            # which ALI operates, so no inverse scaling is applied here.
            x = self.generative_model.decode(
                to_vector(self.dataset.data["features"][query_id]).T)
        save_path = os.path.join(self.save_path_queries,
                                 "pointquery_%d_%d.png" % (self.n_queries + 1, query_id))
        # x is indexed as (sample, channel, row, col) below -- assumed decoder
        # output layout; a single channel is saved as a greyscale image.
        if x.shape[1] == 1:
            plt.imsave(save_path, x[0, 0, :, :], cmap=cm.Greys)
        else:
            plt.imsave(save_path, x[0, :, :, :].transpose(1, 2, 0), cmap=cm.Greys_r)
        # Transform to the original space, in which ALI operates.
        decoded_images = self.generative_model.decode(
            self.dataset.scaling_transformation.inverse_transform(line_segment))
        figure = plt.figure()
        grid = ImageGrid(figure, 111, (1, decoded_images.shape[0]), axes_pad=0.1)
        for image, axis in zip(decoded_images, grid):
            if image.shape[0] == 1:
                axis.imshow(image[0, :, :].squeeze(),
                            cmap=cm.Greys, interpolation='nearest')
            else:
                axis.imshow(image.transpose(1, 2, 0).squeeze(),
                            cmap=cm.Greys_r, interpolation='nearest')
            # Hides ticks, tick labels and frame; no per-axis label blanking needed.
            axis.axis('off')
        save_path = os.path.join(self.save_path_queries,
                                 "linequery_%d_%d.pdf" % (self.n_queries + 1, query_id))
        figure.savefig(save_path, transparent=True, bbox_inches='tight')
    except Exception:
        # Single-argument print gives identical output on Python 2 and 3.
        print("EXCEPTION: " + traceback.format_exc())
        # A bare raise keeps the original traceback (``raise e`` discards it
        # on Python 2).
        raise
    finally:
        # This runs once per query; close the figure so figures do not pile up.
        if figure is not None:
            plt.close(figure)