def grid_plot2d(Q, P, data_loader, params):
    """Plot images produced by ``P`` over a regular grid of 2-D latent codes.

    Both latent coordinates take the values ``np.arange(-10, 10, 1.5)``
    (14 values: -10.0, -8.5, ..., 9.5). The result is a 14 x 14 mosaic drawn
    into the current matplotlib figure. The grid row selects the first latent
    coordinate and the grid column selects the second.

    Args:
        Q: Module that is only switched to eval mode here; it is not called
            (presumably the encoder -- TODO confirm).
        P: Module called with a ``(1, 2)`` latent tensor. Its output must hold
            28 * 28 elements, since it is reshaped to ``(28, 28)`` for display
            (presumably the decoder -- TODO confirm).
        data_loader: Unused; kept for interface compatibility.
        params: Mapping whose ``'cuda'`` entry selects whether the latent grid
            is moved to the GPU before calling ``P``.

    Note:
        ``Q`` and ``P`` are left in eval mode on return. Nothing is returned;
        the caller is expected to show or save the current figure.
    """
    Q.eval()
    P.eval()
    cuda = params['cuda']

    # The same grid values are used for both latent axes.
    values = torch.from_numpy(np.arange(-10, 10, 1.5).astype('float32'))
    z1 = Variable(values)
    z2 = Variable(values)
    if cuda:
        z1, z2 = z1.cuda(), z2.cuda()
    nx, ny = len(z1), len(z2)

    # Draw into the current figure without adding an extra full-size axes
    # that would sit behind the grid of small subplots.
    fig = plt.gcf()
    gs = gridspec.GridSpec(nx, ny, hspace=0.05, wspace=0.05)
    for i, g in enumerate(gs):
        # GridSpec cells are enumerated row-major: row = i // ny, col = i % ny.
        row, col = divmod(i, ny)
        # torch.stack (unlike torch.cat) accepts both 0-d and 1-element
        # tensors, so this works across PyTorch versions; view(1, 2) makes a
        # batch containing a single 2-D latent code.
        z = torch.stack((z1[row], z2[col])).view(1, 2)
        x = P(z)
        ax = fig.add_subplot(g)
        img = x.data.cpu().numpy().reshape(28, 28)
        ax.imshow(img)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('auto')
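# (scraped web-page residue, not code: "评论列表" = "Comment list")
# (scraped web-page residue, not code: "文章目录" = "Table of contents")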