def generate_images_line_save(self, line_segment, query_id, image_original_space=None):
    """Save the images of a point query and of the line query built from it.

    Two files are written to ``self.save_path_queries``, where ``<n>`` is
    ``self.n_queries + 1``:

    * ``pointquery_<n>_<query_id>.png`` -- the decoded query point;
    * ``linequery_<n>_<query_id>.pdf`` -- a single-row grid with one decoded
      image per point of ``line_segment``.

    The ID of the query point from which the query line was generated is
    therefore part of both filenames.

    :param line_segment: points of the query line. They are passed through
        ``self.dataset.scaling_transformation.inverse_transform`` before
        decoding, i.e. they are expected in the scaled space.
    :param query_id: index of the query point in
        ``self.dataset.data["features"]``; also used in the filenames.
    :param image_original_space: optional query point already in the
        original space. When given, it is decoded instead of the dataset row
        selected by ``query_id``.
    :return: None
    :raises Exception: any error is printed with its traceback and re-raised
        unchanged.
    """

    def _displayable(image):
        # Turn one channels-first image (C, H, W) into something matplotlib
        # can render, plus the colormap to use: a single channel becomes a
        # 2-D greyscale array, otherwise channels are moved last (H, W, C).
        if image.shape[0] == 1:
            return image[0, :, :], cm.Greys
        return image.transpose(1, 2, 0), cm.Greys_r

    try:
        if image_original_space is not None:
            point = image_original_space
        else:
            # Rows of dataset.data["features"] are already in the original
            # space in which ALI operates, so no inverse scaling is needed.
            point = to_vector(self.dataset.data["features"][query_id])
        x = self.generative_model.decode(point.T)

        pixels, cmap = _displayable(x[0])
        save_path = os.path.join(self.save_path_queries,
                                 "pointquery_%d_%d.png" % (self.n_queries + 1, query_id))
        plt.imsave(save_path, pixels, cmap=cmap)

        # The line segment is transformed back to the original space, in
        # which ALI operates, before decoding.
        decoded_images = self.generative_model.decode(
            self.dataset.scaling_transformation.inverse_transform(line_segment))

        figure = plt.figure()
        try:
            grid = ImageGrid(figure, 111, (1, decoded_images.shape[0]), axes_pad=0.1)
            for image, axis in zip(decoded_images, grid):
                pixels, cmap = _displayable(image)
                axis.imshow(pixels.squeeze(), cmap=cmap, interpolation='nearest')
                # Hides ticks, tick labels and frame in one call.
                axis.axis('off')
            save_path = os.path.join(self.save_path_queries,
                                     "linequery_%d_%d.pdf" % (self.n_queries + 1, query_id))
            figure.savefig(save_path, transparent=True, bbox_inches='tight')
        finally:
            # pyplot keeps every figure alive until it is closed; without
            # this, each query leaks one figure.
            plt.close(figure)
    except Exception:
        print("EXCEPTION: " + traceback.format_exc())
        # Bare raise keeps the original traceback (``raise e`` resets it
        # under Python 2).
        raise
评论列表
文章目录