query_strategy.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:ActiveBoundary 作者: MiriamHu 项目源码 文件源码
def generate_images_line_save(self, line_segment, query_id, image_original_space=None):
        """
        Decode and save a query point and the query line generated from it.

        Two files are written to ``self.save_path_queries``. Both names carry
        ``self.n_queries + 1`` and ``query_id``, so a saved line query can be
        traced back to the query point it was generated from:

        * ``pointquery_<n>_<query_id>.png`` -- the decoded query point.
        * ``linequery_<n>_<query_id>.pdf`` -- every decoded point of
          ``line_segment``, laid out side by side in a single-row image grid.

        :param line_segment: points along the query line, in the scaled space.
            Before decoding, they are mapped back to the original space with
            ``self.dataset.scaling_transformation.inverse_transform``.
        :param query_id: index of the query point in
            ``self.dataset.data["features"]``. It is used to look up the point
            when ``image_original_space`` is not given, and in both filenames.
        :param image_original_space: optional query point that is already in
            the original (unscaled) space. When given, it is decoded instead of
            the dataset row.
        :return: None
        :raises Exception: any error is printed with its full traceback and
            then re-raised unchanged.
        """
        try:
            if image_original_space is not None:
                point_input = image_original_space.T
            else:
                # dataset.data["features"] is already in the original space in
                # which ALI operates, so no inverse scaling is applied here.
                point_input = to_vector(self.dataset.data["features"][query_id]).T
            x = self.generative_model.decode(point_input)
            # NOTE(review): the decoder output is indexed as a rank-4 array and
            # presumably laid out (batch, channels, height, width). Confirm this
            # against the generative model.
            save_path = os.path.join(self.save_path_queries,
                                     "pointquery_%d_%d.png" % (self.n_queries + 1, query_id))
            if x.shape[1] == 1:
                plt.imsave(save_path, x[0, 0, :, :], cmap=cm.Greys)
            else:
                # imsave expects channels last.
                plt.imsave(save_path, x[0, :, :, :].transpose(1, 2, 0), cmap=cm.Greys_r)

            # Transform to the original space, in which ALI operates.
            decoded_images = self.generative_model.decode(
                self.dataset.scaling_transformation.inverse_transform(line_segment))
            figure = plt.figure()
            try:
                grid = ImageGrid(figure, 111, (1, decoded_images.shape[0]), axes_pad=0.1)
                for image, axis in zip(decoded_images, grid):
                    if image.shape[0] == 1:
                        axis.imshow(image[0, :, :].squeeze(),
                                    cmap=cm.Greys, interpolation='nearest')
                    else:
                        axis.imshow(image.transpose(1, 2, 0).squeeze(),
                                    cmap=cm.Greys_r, interpolation='nearest')
                    # axis('off') already hides ticks, tick labels and spines.
                    # Blanking the tick labels separately is redundant and
                    # triggers FixedFormatter warnings on newer matplotlib.
                    axis.axis('off')
                save_path = os.path.join(self.save_path_queries,
                                         "linequery_%d_%d.pdf" % (self.n_queries + 1, query_id))
                figure.savefig(save_path, transparent=True, bbox_inches='tight')
            finally:
                # pyplot keeps every figure alive until it is closed. This method
                # runs once per query, so an unclosed figure leaks memory on
                # every call, including calls that fail partway through.
                plt.close(figure)
        except Exception:
            # A single concatenated argument prints the same text under both
            # Python 2 (print statement) and Python 3 (print function).
            print("EXCEPTION: " + traceback.format_exc())
            # A bare raise keeps the original traceback; `raise e` drops it on
            # Python 2.
            raise
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号