def __init__(self, savefolder, imgdim, args, network):
    """Prepare fixed inputs for a 2-D latent-manifold visualization grid.

    Builds a ``num_rows x num_rows`` grid of 2-D codes spanning [-2, 2] on
    each axis (``self.code``, shape ``(num_rows**2, 2)``) and one
    standard-normal noise vector per grid cell (``self.z``, shape
    ``(num_rows**2, prod(network.code_dims))``, float32).

    Args:
        savefolder: Forwarded unchanged to the base visualizer constructor.
        imgdim: Forwarded unchanged to the base visualizer constructor.
        args: Run configuration. ``args.parts`` is read directly;
            ``self.args.num_rows`` and ``self.args.ngpus`` are read after the
            base constructor runs (presumably it stores ``args`` as
            ``self.args`` -- TODO confirm against the base class).
        network: Model whose ``code_dims`` determine the noise vector size.

    Raises:
        ValueError: If the flattened code size is too small for
            ``args.parts`` (neither layout flag holds).
    """
    super(ManifoldVisualizer, self).__init__(savefolder, imgdim, args)
    self.network = network
    self.name = "manifold"
    self.parts = args.parts

    # The code dims are flattened into a single size, so z_dim has exactly
    # one entry here.
    z_dim = [int(np.prod(self.network.code_dims))]
    self.flat_flag = z_dim[0] >= 2 * self.parts
    # NOTE(review): since z_dim has a single entry, ``len(z_dim) > 1`` is
    # False and this flag is always False. The expression is kept as-is so
    # that code reading the attribute sees the same value.
    self.hierachical_flag = (
        z_dim[0] >= self.parts and len(z_dim) > 1 and z_dim[1] >= 2
    )
    # Explicit check instead of ``assert`` so it is not stripped under -O.
    if not (self.flat_flag or self.hierachical_flag):
        raise ValueError(
            f"code size {z_dim[0]} is too small for parts={self.parts}: "
            f"need at least {2 * self.parts} for the flat layout"
        )

    num_rows = self.args.num_rows
    device = torch.device("cuda" if self.args.ngpus > 0 else "cpu")

    # code_x[i, j] = lin[j] and code_y[i, j] = lin[i]. Stacking on the last
    # axis and flattening yields num_rows**2 (x, y) rows, row-major over
    # the grid.
    lin = torch.linspace(-2, 2, steps=num_rows)
    code_x = lin.view(1, num_rows).repeat(num_rows, 1)
    code_y = code_x.t()
    self.code = torch.stack([code_x, code_y], dim=2).view(-1, 2).to(device)

    # One noise vector per grid cell. torch.empty(...).normal_() is the
    # non-deprecated equivalent of torch.FloatTensor(*shape).normal_();
    # float32 is pinned explicitly because FloatTensor always was float32.
    z_shape = [num_rows * num_rows, *z_dim]
    self.z = torch.empty(z_shape, dtype=torch.float32, device=device).normal_()
评论列表
文章目录