def plot_norm_points(self, Inputs_N, e, Perms, scales, fig=1):
    """Scatter-plot batch item 0's points, coloured by group, one subplot per scale.

    For scale level ``i`` (one subplot each), the points selected by
    ``Perms[i]`` are drawn and coloured by their group id taken from ``e``
    after ``i`` floor-halvings, so each successive subplot uses half as many
    colours. The figure is saved to ``<self.path>/visualize_example.png``.

    Args:
        Inputs_N: sequence whose first element is a tensor of points. Only
            batch item 0 is plotted, using columns 0 and 1 as x/y
            (presumably shape (batch, n_points, >=2) -- TODO confirm).
        e: tensor of per-point group ids. It is sorted along dim 1 and batch
            item 0 is used (assumes ids are integers in ``[0, 2 ** scales)``
            -- TODO confirm against caller).
        Perms: list of index tensors, one per scale; batch item 0 of each is
            used. Entries ``<= 0`` are dropped and the remaining ones are
            treated as 1-based indices into the points.
        scales: number of levels; subplot ``i`` uses ``2 ** (scales - i)``
            colours.
        fig: matplotlib figure number to draw into (it is cleared first).
    """
    points_all = Inputs_N[0][0].data.cpu().numpy()
    groups = torch.sort(e, 1)[0][0].data.cpu().numpy()
    perms = [perm[0].data.cpu().numpy() for perm in Perms]

    plt.figure(fig)
    plt.clf()
    for i, perm in enumerate(perms):
        n_groups = 2 ** (scales - i)
        plt.subplot(1, len(perms), i + 1)
        colors = cm.rainbow(np.linspace(0, 1, n_groups))
        # Drop non-positive entries (presumably padding) and convert the
        # remaining 1-based indices to 0-based.
        idx = perm[np.where(perm > 0)[0]] - 1
        points = points_all[idx]
        group_of_point = groups[idx]
        # range (not xrange) so this also runs on Python 3.
        for node in range(n_groups):
            pts = points[np.where(group_of_point == node)[0]]
            # Pass the colour as a single-row 2-D array: a bare length-4 RGBA
            # sequence passed as ``c`` is value-mapped by matplotlib when the
            # number of points is also 4, producing the wrong colours.
            plt.scatter(pts[:, 0], pts[:, 1], c=colors[node:node + 1])
        # Merge pairs of sibling groups for the next (coarser) level.
        groups = groups // 2
    path = os.path.join(self.path, 'visualize_example.png')
    plt.savefig(path)
评论列表
文章目录