viz.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:AAE_pytorch 作者: fducau 项目源码 文件源码
def grid_plot2d(Q, P, data_loader, params):
    Q.eval()
    P.eval()

    cuda = params['cuda']

    z1 = Variable(torch.from_numpy(np.arange(-10, 10, 1.5).astype('float32')))
    z2 = Variable(torch.from_numpy(np.arange(-10, 10, 1.5).astype('float32')))
    if cuda:
        z1, z2 = z1.cuda(), z2.cuda()

    nx, ny = len(z1), len(z2)
    plt.subplot()
    gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)

    for i, g in enumerate(gs):
        z = torch.cat((z1[i / ny], z2[i % nx])).resize(1, 2)
        x = P(z)

        ax = plt.subplot(g)
        img = np.array(x.data.tolist()).reshape(28, 28)
        ax.imshow(img, )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('auto')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号