python类gcf()的实例源码

publish.py 文件源码 项目:autoxd 作者: nessessary 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def gcf(self):
        """Return the figure that pylab currently treats as the active one."""
        current_figure = pl.gcf()
        return current_figure
publish.py 文件源码 项目:autoxd 作者: nessessary 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def save(self):
        """Write every tracked figure to a PNG file and record the file names.

        If ``self.figs`` is empty, the currently active pylab figure is
        registered first and ``self.use_figure`` is switched on.

        Each file is named ``path + name + "_" + pid + "_" + index + ".png"``.
        ``index`` is the position of the figure in ``self.figs``, or, when
        ``self.use_figure`` is set, the number of images recorded so far.
        Every file name is appended to ``self.imgs`` and the last one is also
        kept in ``self.cur_img_fname``.
        """
        if not self.figs:
            # Nothing was registered explicitly: fall back to the active figure.
            self.use_figure = True
            self.figs.append(pl.gcf())
        pid = str(os.getpid())
        for i, fig in enumerate(self.figs):
            index = len(self.imgs) if self.use_figure else i
            fname = self.path + self.name + "_" + pid + '_' + str(index) + ".png"
            self.imgs.append(fname)
            self.cur_img_fname = fname
            # Save the figure being iterated. The module-level pl.savefig()
            # always writes the *current* figure, so with several tracked
            # figures every file would contain the same image.
            fig.savefig(fname, dpi=70)
figrc.py 文件源码 项目:tap 作者: mfouesneau 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def tight_layout():
    """Apply a tight layout in a way that depends on the active backend.

    For the 'agg' and 'macosx' backends (compared case-insensitively) the
    current figure is flagged with ``set_tight_layout(True)``; for every other
    backend ``plt.tight_layout()`` is called directly.
    """
    from matplotlib import get_backend
    from pylab import gcf

    backend_name = get_backend().lower()
    if backend_name not in ('agg', 'macosx'):
        plt.tight_layout()
        return
    gcf().set_tight_layout(True)
utils.py 文件源码 项目:chainer-adversarial-autoencoder 作者: fukuta0614 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def visualize_10_2d_gaussian_prior(n_z, y_label, visualization_dir=None):
    """Scatter-plot samples of the labelled 2D Gaussian-mixture prior.

    One sample per entry of ``y_label`` is requested from
    ``sample_z_from_n_2d_gaussian_mixture`` and its first two coordinates are
    plotted, coloured by the label.

    Args:
        n_z: forwarded unchanged to the sampler.
        y_label: sequence of integer labels used to index the 10-colour
            palette (so values are expected to lie in 0..9).
        visualization_dir: when not None, the figure is also saved as
            ``<visualization_dir>/10_2d-gaussian.png``.

    The figure is always shown with ``pylab.show()``.
    """
    z_batch = sample_z_from_n_2d_gaussian_mixture(len(y_label), n_z, y_label, 10, False)
    z_batch = z_batch.data

    fig = pylab.gcf()
    fig.set_size_inches(15, 12)
    pylab.clf()
    colors = ["#2103c8", "#0e960e", "#e40402", "#05aaa8", "#ac02ab", "#aba808", "#151515", "#94a169", "#bec9cd",
              "#6a6551"]
    # One scatter call with a per-point colour list instead of one call (and
    # one artist) per sample; `range` keeps this working on Python 2 and 3.
    point_colors = [colors[y_label[n]] for n in range(z_batch.shape[0])]
    if point_colors:
        pylab.scatter(z_batch[:, 0], z_batch[:, 1], c=point_colors, s=40, marker="o",
                      edgecolors='none')

    classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
    # Coloured proxy rectangles serve as legend handles, one per class.
    recs = [mpatches.Rectangle((0, 0), 1, 1, fc=color) for color in colors]

    ax = pylab.subplot(111)
    box = ax.get_position()
    # Shrink the axes to 80% width so the legend fits to the right of them.
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(recs, classes, loc="center left", bbox_to_anchor=(1.1, 0.5))
    pylab.xticks(pylab.arange(-4, 5))
    pylab.yticks(pylab.arange(-4, 5))
    pylab.xlabel("z1")
    pylab.ylabel("z2")
    if visualization_dir is not None:
        pylab.savefig("%s/10_2d-gaussian.png" % visualization_dir)
    pylab.show()
utils.py 文件源码 项目:chainer-adversarial-autoencoder 作者: fukuta0614 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def visualize_labeled_z(xp, model, x, y_label, visualization_dir, epoch, gpu=False):
    """Encode ``x`` with ``model`` and save a label-coloured scatter of the codes.

    Args:
        xp: array module providing ``asarray`` (used to wrap ``x``).
        model: object whose ``encode(x, test=True)`` returns a chainer
            variable; its first two columns are plotted.
        x: input batch to encode.
        y_label: sequence of integer labels used to index the 10-colour
            palette (so values are expected to lie in 0..9).
        visualization_dir: directory the image is written to.
        epoch: inserted into the file name ``labeled_z_<epoch>.png``.
        gpu: not used by this function; kept for interface compatibility.
    """
    x = chainer.Variable(xp.asarray(x))
    z_batch = model.encode(x, test=True)
    z_batch.to_cpu()
    z_batch = z_batch.data
    fig = pylab.gcf()
    fig.set_size_inches(8.0, 8.0)
    pylab.clf()
    colors = ["#2103c8", "#0e960e", "#e40402", "#05aaa8", "#ac02ab", "#aba808", "#151515", "#94a169", "#bec9cd",
              "#6a6551"]
    # One scatter call with a per-point colour list instead of one call (and
    # one artist) per sample; `range` keeps this working on Python 2 and 3.
    point_colors = [colors[y_label[n]] for n in range(z_batch.shape[0])]
    if point_colors:
        pylab.scatter(z_batch[:, 0], z_batch[:, 1], c=point_colors, s=40, marker="o",
                      edgecolors='none')

    classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
    # Coloured proxy rectangles serve as legend handles, one per class.
    recs = [mpatches.Rectangle((0, 0), 1, 1, fc=color) for color in colors]

    ax = pylab.subplot(111)
    box = ax.get_position()
    # Shrink the axes to 80% width so the legend fits to the right of them.
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(recs, classes, loc="center left", bbox_to_anchor=(1.1, 0.5))
    pylab.xticks(pylab.arange(-4, 5))
    pylab.yticks(pylab.arange(-4, 5))
    pylab.xlabel("z1")
    pylab.ylabel("z2")
    pylab.savefig("{}/labeled_z_{}.png".format(visualization_dir, epoch))
plotting.py 文件源码 项目:ugali 作者: DarkEnergySurvey 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def drawImage(self,ax=None,invert=True):
        """Fetch a survey cut-out image and draw it with ``drawProjImage``.

        The SDSS image is used when ``self.config['data']['survey']`` equals
        'sdss', otherwise the DSS image; both are fetched with
        ``self.image_kwargs`` and the axes are annotated accordingly using
        ``self.label_kwargs``.

        Args:
            ax: axes to annotate; falls back to ``plt.gca()`` when falsy.
            invert: not referenced in the body.

        Returns:
            Whatever ``drawProjImage`` returns for the drawn image.
        """
        if not ax:
            ax = plt.gca()

        if self.config['data']['survey']=='sdss':
            # Optical image, flipped vertically (JPEG orientation):
            # https://github.com/matplotlib/matplotlib/issues/101
            image = ugali.utils.plotting.getSDSSImage(**self.image_kwargs)[::-1]
            caption = "SDSS Image"
        else:
            # DSS image is flipped along both axes.
            image = ugali.utils.plotting.getDSSImage(**self.image_kwargs)[::-1,::-1]
            caption = "DSS Image"
        ax.annotate(caption,**self.label_kwargs)

        radius = self.image_kwargs.get('radius',1.0)

        # Celestial coordinates spanning [-radius, radius] on both axes
        x_coords = np.linspace(-radius,radius,image.shape[0])
        y_coords = np.linspace(-radius,radius,image.shape[1])
        xx, yy = np.meshgrid(x_coords,y_coords)

        drawn = drawProjImage(xx,yy,image,cmap='gray',coord='C')

        # Remove the axes stored as ``ax.cax`` if there is one.
        try:
            plt.gcf().delaxes(ax.cax)
        except AttributeError:
            pass

        return drawn
plotting.py 文件源码 项目:ugali 作者: DarkEnergySurvey 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def drawMembersSpatial(self,data):
        """Scatter-plot objects around (self.ra, self.dec), coloured by 'PROB'.

        Args:
            data: table with 'RA', 'DEC' and 'PROB' columns, or a string that
                is opened with ``pyfits`` (extension 1 is used).

        Objects inside the +/-0.25 window with PROB <= 0.05 are drawn as small
        grey dots; the remaining ones are drawn in ascending PROB order with a
        'jet' colour map, and a colour bar is attached to the right.
        """
        ax = plt.gca()
        if isinstance(data,basestring):
            data = pyfits.open(data)[1].data

        lo, hi = -0.25, 0.25
        xx,yy = np.meshgrid(np.linspace(lo,hi),np.linspace(lo,hi))

        x_prob, y_prob = sphere2image(self.ra, self.dec, data['RA'], data['DEC'])

        # Restrict everything to the plotting window once.
        inside = (x_prob > lo)&(x_prob < hi) & (y_prob > lo)&(y_prob < hi)
        x_in = x_prob[inside]
        y_in = y_prob[inside]
        prob_in = data['PROB'][inside]

        likely = prob_in > 5.e-2
        # Ascending order, so the highest-probability points are drawn last.
        order = numpy.argsort(prob_in[likely])

        plt.scatter(x_in[~likely], y_in[~likely],
                    marker='o', s=2, c='0.75', edgecolor='none')
        sc = plt.scatter(x_in[likely][order],
                         y_in[likely][order],
                         c=prob_in[likely][order],
                         marker='o', s=10, edgecolor='none', cmap='jet', vmin=0., vmax=1.)

        drawProjImage(xx,yy,None,coord='C')

        plt.xticks([-0.2, 0., 0.2])
        plt.yticks([-0.2, 0., 0.2])

        # Colour bar in its own axes, appended to the right of the main axes.
        ax_cb = make_axes_locatable(ax).new_horizontal(size="7%", pad=0.1)
        plt.gcf().add_axes(ax_cb)
        pylab.colorbar(sc, cax=ax_cb, orientation='vertical', ticks=[0, 0.2, 0.4, 0.6, 0.8, 1.0], label='Membership Probability')
        ax_cb.yaxis.tick_right()
diagnostics.py 文件源码 项目:fang 作者: rgrosse 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def plot(self):
        """Redraw the objective figures tracked by this object.

        When ``self.vis`` is set, the 'log probs' figure and its zoomed
        variant are drawn from ``self.fe_info``; when ``self.target_moments``
        is set, the 'moment matching objective' figure and its zoomed variant
        are drawn from ``self.dp_info``. Every figure shows a raw ('main',
        blue) and an averaged ('avg', red) curve.
        """
        def redraw(title, objective_info, **raw_options):
            # Select (or create) the named figure and start from a clean slate.
            pylab.figure(title)
            pylab.clf()
            plot_objfn(objective_info['main'], self.log_Z_info['main'], 'b', label='Raw', **raw_options)
            # NOTE(review): the averaged curve never receives zoom=True, even
            # in the zoomed figures -- confirm this is intentional.
            plot_objfn(objective_info['avg'], self.log_Z_info['avg'], 'r', label='Averaged')
            pylab.title(title)
            pylab.legend(loc='lower right')
            pylab.gcf().canvas.draw()

        if self.vis is not None:
            redraw('log probs', self.fe_info)
            redraw('log probs (zoomed)', self.fe_info, zoom=True)

        if self.target_moments is not None:
            redraw('moment matching objective', self.dp_info)
            redraw('moment matching objective (zoomed)', self.dp_info, zoom=True)
from_scratch.py 文件源码 项目:fang 作者: rgrosse 项目源码 文件源码 阅读 43 收藏 0 点赞 0 评论 0
def after_step(self, rbm, trainer, i):
        """Save and/or display training state after update ``i`` (0-based).

        The 1-based step number ``i + 1`` is looked up in
        ``self.expt.save_after`` and ``self.expt.show_after`` to decide
        whether to dump state to storage and whether to refresh figures.
        """
        step = i + 1
        expt = self.expt

        do_save = step in expt.save_after
        do_show = step in expt.show_after
        wants_figure = do_save or do_show

        if do_save:
            if expt.save_particles:
                storage.dump(trainer.fantasy_particles, expt.pcd_particles_file(step))
            storage.dump(rbm, expt.rbm_file(step))
            if hasattr(trainer, 'avg_rbm'):
                storage.dump(trainer.avg_rbm, expt.avg_rbm_file(step))
            # Wall-clock time elapsed since self.t0.
            storage.dump(time.time() - self.t0, expt.time_file(step))

        if 'particles' in self.subset and wants_figure:
            particles_fig = rbm_vis.show_particles(
                rbm, trainer.fantasy_particles, expt.dataset, display=do_show,
                figtitle='PCD particles ({} updates)'.format(step))
            if do_show:
                pylab.gcf().canvas.draw()
            if do_save:
                misc.save_image(particles_fig, expt.pcd_particles_figure_file(step))

        if 'gibbs_chains' in self.subset and wants_figure:
            chains_fig = diagnostics.show_chains(
                rbm, trainer.fantasy_particles, expt.dataset, display=do_show,
                figtitle='Gibbs chains (iteration {})'.format(step))
            if do_save:
                misc.save_image(chains_fig, expt.gibbs_chains_figure_file(step))

        if 'objective' in self.subset:
            self.log_prob_tracker.update(rbm, trainer.fantasy_particles)

        if do_show:
            pylab.gcf().canvas.draw()
plot.py 文件源码 项目:unrolled-gan 作者: musyoku 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def plot_scatter(data, dir=None, filename="scatter", color="blue"):
    """Save a scatter plot of the first two columns of ``data``.

    Args:
        data: 2D array; column 0 is plotted on x and column 1 on y.
        dir: output directory (required); created, including parents, when
            it does not exist yet.
        filename: file name without extension; ".png" is appended.
        color: colour of the points.

    Raises:
        ValueError: if ``dir`` is None.
        OSError: if ``dir`` does not exist and cannot be created.
    """
    if dir is None:
        raise ValueError("dir must be given: it is where the plot is written")
    try:
        os.makedirs(dir)
    except OSError:
        # An already existing directory is fine; anything else is a real
        # failure that would otherwise only surface later in savefig().
        if not os.path.isdir(dir):
            raise
    fig = pylab.gcf()
    fig.set_size_inches(16.0, 16.0)
    pylab.clf()
    pylab.scatter(data[:, 0], data[:, 1], s=20, marker="o", edgecolors="none", color=color)
    pylab.xlim(-4, 4)
    pylab.ylim(-4, 4)
    pylab.savefig("{}/{}.png".format(dir, filename))
plot_true.py 文件源码 项目:unrolled-gan 作者: musyoku 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def plot_scatter(data, dir=None, filename="scatter", color="blue"):
    """Save a scatter plot of the first two columns of ``data``.

    Args:
        data: 2D array; column 0 is plotted on x and column 1 on y.
        dir: output directory (required); created, including parents, when
            it does not exist yet.
        filename: name of the output file inside ``dir``; it is passed to
            ``savefig`` as given, no extension is appended here.
        color: colour of the points.

    Raises:
        ValueError: if ``dir`` is None.
        OSError: if ``dir`` does not exist and cannot be created.
    """
    if dir is None:
        raise ValueError("dir must be given: it is where the plot is written")
    try:
        os.makedirs(dir)
    except OSError:
        # An already existing directory is fine; anything else is a real
        # failure that would otherwise only surface later in savefig().
        if not os.path.isdir(dir):
            raise
    fig = pylab.gcf()
    fig.set_size_inches(16.0, 16.0)
    pylab.clf()
    pylab.scatter(data[:, 0], data[:, 1], s=20, marker="o", edgecolors="none", color=color)
    pylab.xlim(-4, 4)
    pylab.ylim(-4, 4)
    pylab.savefig("{}/{}".format(dir, filename))
plot.py 文件源码 项目:LSGAN 作者: musyoku 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def plot_scatter(data, dir=None, filename="scatter", color="blue"):
    """Save a scatter plot of the first two columns of ``data``.

    Args:
        data: 2D array; column 0 is plotted on x and column 1 on y.
        dir: output directory (required); created, including parents, when
            it does not exist yet.
        filename: file name without extension; ".png" is appended.
        color: colour of the points.

    Raises:
        ValueError: if ``dir`` is None.
        OSError: if ``dir`` does not exist and cannot be created.
    """
    if dir is None:
        raise ValueError("dir must be given: it is where the plot is written")
    try:
        os.makedirs(dir)
    except OSError:
        # An already existing directory is fine; anything else is a real
        # failure that would otherwise only surface later in savefig().
        if not os.path.isdir(dir):
            raise
    fig = pylab.gcf()
    fig.set_size_inches(16.0, 16.0)
    pylab.clf()
    pylab.scatter(data[:, 0], data[:, 1], s=20, marker="o", edgecolors="none", color=color)
    pylab.xlim(-4, 4)
    pylab.ylim(-4, 4)
    pylab.savefig("{}/{}.png".format(dir, filename))
plot_true.py 文件源码 项目:LSGAN 作者: musyoku 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def plot_scatter(data, dir=None, filename="scatter", color="blue"):
    """Save a scatter plot of the first two columns of ``data``.

    Args:
        data: 2D array; column 0 is plotted on x and column 1 on y.
        dir: output directory (required); created, including parents, when
            it does not exist yet.
        filename: name of the output file inside ``dir``; it is passed to
            ``savefig`` as given, no extension is appended here.
        color: colour of the points.

    Raises:
        ValueError: if ``dir`` is None.
        OSError: if ``dir`` does not exist and cannot be created.
    """
    if dir is None:
        raise ValueError("dir must be given: it is where the plot is written")
    try:
        os.makedirs(dir)
    except OSError:
        # An already existing directory is fine; anything else is a real
        # failure that would otherwise only surface later in savefig().
        if not os.path.isdir(dir):
            raise
    fig = pylab.gcf()
    fig.set_size_inches(16.0, 16.0)
    pylab.clf()
    pylab.scatter(data[:, 0], data[:, 1], s=20, marker="o", edgecolors="none", color=color)
    pylab.xlim(-4, 4)
    pylab.ylim(-4, 4)
    pylab.savefig("{}/{}".format(dir, filename))
visualizer.py 文件源码 项目:adgm 作者: musyoku 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def tile_rgb_images(x, dir=None, filename="x"):
    """Save up to the first 100 images of ``x`` as a 10x10 tiled PNG.

    Args:
        x: sequence of images; each is clipped to [0, 1] before display.
            When fewer than 100 images are given, the remaining tiles are
            simply left empty.
        dir: output directory (required); created, including parents, when
            it does not exist yet.
        filename: file name without extension; ".png" is appended.

    Raises:
        ValueError: if ``dir`` is None.
        OSError: if ``dir`` does not exist and cannot be created.
    """
    if dir is None:
        raise ValueError("dir must be given: it is where the image is written")
    try:
        os.makedirs(dir)
    except OSError:
        # An already existing directory is fine; anything else is a real
        # failure that would otherwise only surface later in savefig().
        if not os.path.isdir(dir):
            raise
    fig = pylab.gcf()
    fig.set_size_inches(16.0, 16.0)
    pylab.clf()
    # Never index past the end of x: the grid holds at most 100 tiles.
    for m in range(min(100, len(x))):
        pylab.subplot(10, 10, m + 1)
        pylab.imshow(np.clip(x[m], 0, 1), interpolation="none")
        pylab.axis("off")
    pylab.savefig("{}/{}.png".format(dir, filename))
util.py 文件源码 项目:variational-autoencoder 作者: musyoku 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def visualize_z(z_batch, dir=None):
    """Save a scatter plot of the first two latent dimensions of ``z_batch``.

    Args:
        z_batch: 2D array; column 0 is plotted as z1 and column 1 as z2.
        dir: output directory (required); created, including parents, when
            it does not exist yet. The image is written to
            ``<dir>/latent_code.png``.

    Raises:
        ValueError: if ``dir`` is None.
        OSError: if ``dir`` does not exist and cannot be created.
    """
    if dir is None:
        raise ValueError("dir must be given: it is where the plot is written")
    try:
        os.makedirs(dir)
    except OSError:
        # An already existing directory is fine; anything else is a real
        # failure that would otherwise only surface later in savefig().
        if not os.path.isdir(dir):
            raise
    fig = pylab.gcf()
    fig.set_size_inches(20.0, 16.0)
    pylab.clf()
    # A single call draws all points as one collection with one consistent
    # colour, instead of creating a separate artist per sample.
    pylab.scatter(z_batch[:, 0], z_batch[:, 1], s=40, marker="o", edgecolors='none')
    pylab.xlabel("z1")
    pylab.ylabel("z2")
    pylab.savefig("%s/latent_code.png" % dir)
vis_corex.py 文件源码 项目:bio_corex 作者: gregversteeg 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def plot_rels(data, labels=None, colors=None, outfile="rels", latent=None, alpha=0.8):
    """Plot every pairwise relationship between the columns of ``data``.

    One scatter panel is drawn per column pair on a grid five panels wide,
    and the figure is written to ``outfile + '.png'``.

    Args:
        data: 2D array of shape (samples, variables).
        labels: one label per variable; defaults to the column indices.
        colors: optional per-sample values, rescaled to [0, 1] and mapped
            through the 'jet' colour map. Used only when ``latent`` is given.
        outfile: output path without extension. Its directory part, if any,
            is created when missing.
        latent: optional per-sample discrete values; each distinct value is
            drawn with its own marker.
        alpha: point opacity used when ``latent`` is None.

    Returns:
        True.
    """
    _, n = data.shape
    if labels is None:
        labels = list(map(str, range(n)))
    ncol = 5
    nrow = int(np.ceil(float(n * (n - 1) / 2) / ncol))
    fig, axs = pylab.subplots(nrow, ncol)
    fig.set_size_inches(5 * ncol, 5 * nrow)
    pairs = list(combinations(range(n), 2))
    # Pairs of low column indices come first.
    pairs = sorted(pairs, key=lambda q: q[0]**2+q[1]**2)
    if colors is not None:
        # Rescale to [0, 1]; the clip guards against a zero range.
        colors = (colors - np.min(colors)) / (np.max(colors) - np.min(colors)).clip(1e-7)

    markers = 'x+.o,<>^^<>,+x.'
    for ax, pair in zip(axs.flat, pairs):
        if latent is None:
            ax.scatter(data[:, pair[0]], data[:, pair[1]], marker='.', edgecolors='none', alpha=alpha)
        else:
            for j, ind in enumerate(np.unique(latent)):
                inds = (latent == ind)
                # Wrap around so that more distinct latent values than
                # markers no longer raises IndexError.
                ax.scatter(data[inds, pair[0]], data[inds, pair[1]], c=colors[inds], cmap=pylab.get_cmap("jet"),
                           marker=markers[j % len(markers)], alpha=0.5, edgecolors='none', vmin=0, vmax=1)

        ax.set_xlabel(shorten(labels[pair[0]]))
        ax.set_ylabel(shorten(labels[pair[1]]))

    # Surplus panels get placeholder content so tight_layout treats them
    # like the others; they are hidden again afterwards.
    unused_axes = axs.flat[axs.size - 1:len(pairs) - 1:-1]
    for ax in unused_axes:
        ax.scatter(data[:, 0], data[:, 1], marker='.')

    pylab.rcParams['font.size'] = 12
    pylab.draw()
    fig.tight_layout()
    for ax in unused_axes:
        ax.set_visible(False)
    filename = outfile + '.png'
    out_dir = os.path.dirname(filename)
    # A bare file name (such as the default "rels") has no directory part;
    # os.makedirs('') would raise, so only create a directory when one is named.
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir)
    fig.savefig(filename)
    pylab.close('all')
    return True
visualize.py 文件源码 项目:adversarial-autoencoder 作者: musyoku 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def plot_analogy():
    """Render an analogy grid for 10 random MNIST test digits to analogy.png.

    The model is loaded from "model.hdf5". Each row shows the original digit
    in the first column, followed (after one empty column) by ten images
    decoded from that digit's latent code ``z`` combined with each of the ten
    one-hot class labels.

    Raises:
        AssertionError: if the model file cannot be loaded.
    """
    _, dataset_test = chainer.datasets.get_mnist()
    images_test, _ = dataset_test._datasets
    dataset_indices = np.arange(0, len(images_test))
    np.random.shuffle(dataset_indices)

    model = Model()
    # Load outside of an `assert` statement: asserts are stripped under
    # `python -O`, which would silently skip loading the model.
    loaded = model.load("model.hdf5")
    if not loaded:
        raise AssertionError("failed to load model.hdf5")

    # normalize to [-1, 1]
    images_test = (images_test - 0.5) * 2

    num_analogies = 10
    num_columns = 10 + 2
    pylab.gray()

    x_batch = images_test[dataset_indices[:num_analogies]]

    # Enter BOTH context managers. `with a() and b():` evaluates the boolean
    # expression first and only enters b(), so backprop was never disabled.
    with chainer.no_backprop_mode(), chainer.using_config("train", False):
        z_batch = model.encode_x_yz(x_batch)[1].data

        # plot original image on the left
        x_batch = (x_batch + 1.0) / 2.0
        for m in range(num_analogies):
            pylab.subplot(num_analogies, num_columns, m * num_columns + 1)
            pylab.imshow(x_batch[m].reshape((28, 28)), interpolation="none")
            pylab.axis("off")

        all_y = np.identity(10, dtype=np.float32)
        for m in range(num_analogies):
            # copy z_batch as many as the number of classes
            fixed_z = np.repeat(z_batch[m].reshape(1, -1), 10, axis=0)
            gen_x = model.decode_yz_x(all_y, fixed_z).data
            gen_x = (gen_x + 1.0) / 2.0
            # plot images generated from each label
            for n in range(10):
                pylab.subplot(num_analogies, num_columns, m * num_columns + 3 + n)
                pylab.imshow(gen_x[n].reshape((28, 28)), interpolation="none")
                pylab.axis("off")

    fig = pylab.gcf()
    fig.set_size_inches(num_analogies, 10)
    pylab.savefig("analogy.png")
visualize.py 文件源码 项目:adversarial-autoencoder 作者: musyoku 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def plot_analogy():
    """Render an analogy grid for 10 random MNIST test digits to analogy.png.

    The model is loaded from "model.hdf5". Each row shows the original digit
    in the first column, followed (after one empty column) by ten images
    decoded from the representation of that digit's latent code ``z``
    combined with each of the ten one-hot class labels.

    Raises:
        AssertionError: if the model file cannot be loaded.
    """
    _, dataset_test = chainer.datasets.get_mnist()
    images_test, _ = dataset_test._datasets
    dataset_indices = np.arange(0, len(images_test))
    np.random.shuffle(dataset_indices)

    model = Model()
    # Load outside of an `assert` statement: asserts are stripped under
    # `python -O`, which would silently skip loading the model.
    loaded = model.load("model.hdf5")
    if not loaded:
        raise AssertionError("failed to load model.hdf5")

    # normalize to [-1, 1]
    images_test = (images_test - 0.5) * 2

    num_analogies = 10
    num_columns = 10 + 2
    pylab.gray()

    x_batch = images_test[dataset_indices[:num_analogies]]

    # Enter BOTH context managers. `with a() and b():` evaluates the boolean
    # expression first and only enters b(), so backprop was never disabled.
    with chainer.no_backprop_mode(), chainer.using_config("train", False):
        z_batch = model.encode_x_yz(x_batch)[1].data

        # plot original image on the left
        x_batch = (x_batch + 1.0) / 2.0
        for m in range(num_analogies):
            pylab.subplot(num_analogies, num_columns, m * num_columns + 1)
            pylab.imshow(x_batch[m].reshape((28, 28)), interpolation="none")
            pylab.axis("off")

        all_y = np.identity(10, dtype=np.float32)
        for m in range(num_analogies):
            # copy z_batch as many as the number of classes
            fixed_z = np.repeat(z_batch[m].reshape(1, -1), 10, axis=0)
            representation = model.encode_yz_representation(all_y, fixed_z)
            gen_x = model.decode_representation_x(representation).data
            gen_x = (gen_x + 1.0) / 2.0
            # plot images generated from each label
            for n in range(10):
                pylab.subplot(num_analogies, num_columns, m * num_columns + 3 + n)
                pylab.imshow(gen_x[n].reshape((28, 28)), interpolation="none")
                pylab.axis("off")

    fig = pylab.gcf()
    fig.set_size_inches(num_analogies, 10)
    pylab.savefig("analogy.png")
visualize.py 文件源码 项目:adversarial-autoencoder 作者: musyoku 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def plot_analogy():
    """Render an analogy grid for 10 random MNIST test digits to analogy.png.

    The model is loaded from "model.hdf5". Each row shows the original digit
    in the first column, followed (after one empty column) by ten images
    decoded from that digit's latent code ``z`` (obtained with
    ``model.encode_x_z``) combined with each of the ten one-hot class labels.

    Raises:
        AssertionError: if the model file cannot be loaded.
    """
    _, dataset_test = chainer.datasets.get_mnist()
    images_test, _ = dataset_test._datasets
    dataset_indices = np.arange(0, len(images_test))
    np.random.shuffle(dataset_indices)

    model = Model()
    # Load outside of an `assert` statement: asserts are stripped under
    # `python -O`, which would silently skip loading the model.
    loaded = model.load("model.hdf5")
    if not loaded:
        raise AssertionError("failed to load model.hdf5")

    # normalize to [-1, 1]
    images_test = (images_test - 0.5) * 2

    num_analogies = 10
    num_columns = 10 + 2
    pylab.gray()

    x_batch = images_test[dataset_indices[:num_analogies]]

    # Enter BOTH context managers. `with a() and b():` evaluates the boolean
    # expression first and only enters b(), so backprop was never disabled.
    with chainer.no_backprop_mode(), chainer.using_config("train", False):
        z_batch = model.encode_x_z(x_batch).data

        # plot original image on the left
        x_batch = (x_batch + 1.0) / 2.0
        for m in range(num_analogies):
            pylab.subplot(num_analogies, num_columns, m * num_columns + 1)
            pylab.imshow(x_batch[m].reshape((28, 28)), interpolation="none")
            pylab.axis("off")

        all_y = np.identity(10, dtype=np.float32)
        for m in range(num_analogies):
            # copy z_batch as many as the number of classes
            fixed_z = np.repeat(z_batch[m].reshape(1, -1), 10, axis=0)
            gen_x = model.decode_yz_x(all_y, fixed_z).data
            gen_x = (gen_x + 1.0) / 2.0
            # plot images generated from each label
            for n in range(10):
                pylab.subplot(num_analogies, num_columns, m * num_columns + 3 + n)
                pylab.imshow(gen_x[n].reshape((28, 28)), interpolation="none")
                pylab.axis("off")

    fig = pylab.gcf()
    fig.set_size_inches(num_analogies, 10)
    pylab.savefig("analogy.png")
visualize.py 文件源码 项目:adversarial-autoencoder 作者: musyoku 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def plot_clusters():
    """Render the clusters found by the model on MNIST test data to clusters.png.

    The model is loaded from "model.hdf5". There is one row per cluster
    (``model.ndim_y`` rows). The first column shows the cluster head, i.e.
    the image decoded from the one-hot cluster label with an all-zero ``z``.
    After one empty column, up to 11 test images assigned to that cluster
    (arg-max of the encoder's label output) follow.

    Raises:
        AssertionError: if the model file cannot be loaded.
    """
    _, dataset_test = chainer.datasets.get_mnist()
    images_test, _ = dataset_test._datasets

    model = Model()
    # Load outside of an `assert` statement: asserts are stripped under
    # `python -O`, which would silently skip loading the model.
    loaded = model.load("model.hdf5")
    if not loaded:
        raise AssertionError("failed to load model.hdf5")

    # normalize to [-1, 1]
    images_test = (images_test - 0.5) * 2

    num_clusters = model.ndim_y
    num_plots_per_cluster = 11
    # head image + one empty spacer column + the cluster members
    num_columns = num_plots_per_cluster + 2
    image_width = 28
    image_height = 28
    pylab.gray()

    # Enter BOTH context managers. `with a() and b():` evaluates the boolean
    # expression first and only enters b(), so backprop was never disabled.
    with chainer.no_backprop_mode(), chainer.using_config("train", False):
        # plot cluster head
        head_y = np.identity(model.ndim_y, dtype=np.float32)
        zero_z = np.zeros((model.ndim_y, model.ndim_z), dtype=np.float32)
        head_x = model.decode_yz_x(head_y, zero_z).data
        head_x = (head_x + 1.0) / 2.0
        for n in range(num_clusters):
            pylab.subplot(num_clusters, num_columns, n * num_columns + 1)
            pylab.imshow(head_x[n].reshape((image_width, image_height)), interpolation="none")
            pylab.axis("off")

        # plot elements in cluster
        counts = [0] * num_clusters
        indices = np.arange(len(images_test))
        np.random.shuffle(indices)
        batchsize = 500

        # Only full batches are processed, in shuffled order.
        for start in range(0, len(images_test) - batchsize + 1, batchsize):
            x_batch = images_test[indices[start:start + batchsize]]
            y_batch = model.encode_x_yz(x_batch)[0].data
            labels = np.argmax(y_batch, axis=1)
            for m in range(labels.size):
                cluster = int(labels[m])
                counts[cluster] += 1
                if counts[cluster] <= num_plots_per_cluster:
                    x = (x_batch[m] + 1.0) / 2.0
                    pylab.subplot(num_clusters, num_columns, cluster * num_columns + 2 + counts[cluster])
                    pylab.imshow(x.reshape((image_width, image_height)), interpolation="none")
                    pylab.axis("off")
            # Every row is full: encoding further batches cannot add panels.
            if all(count >= num_plots_per_cluster for count in counts):
                break

        fig = pylab.gcf()
        fig.set_size_inches(num_plots_per_cluster, num_clusters)
        pylab.savefig("clusters.png")


问题


面经


文章

微信
公众号

扫码关注公众号