visualizer.py source code

python

Project: DisentangleVAE, Author: Jueast
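    # Method of ManifoldVisualizer (the class statement is not shown in this excerpt).
    # Needs `import numpy as np` and `import torch` at module level.
    def __init__(self, savefolder, imgdim, args, network):
        super(ManifoldVisualizer, self).__init__(savefolder, imgdim, args)
        self.network = network
        self.name = "manifold"
        self.parts = args.parts

        # Total latent size, flattened to a single dimension.
        z_dim = [int(np.prod(self.network.code_dims))]
        # Flat layout: at least two latent coordinates per part.
        self.flat_flag = z_dim[0] >= 2 * self.parts
        # NOTE: z_dim has exactly one element here, so len(z_dim) > 1 is False
        # and this flag is always False as written; the assert below therefore
        # reduces to checking flat_flag.
        self.hierachical_flag = z_dim[0] >= self.parts and len(z_dim) > 1 and z_dim[1] >= 2
        assert self.flat_flag or self.hierachical_flag
        # One latent sample per grid cell: shape (num_rows * num_rows, latent size).
        z_dim.insert(0, self.args.num_rows * self.args.num_rows)

        num_rows = self.args.num_rows
        # num_rows x num_rows grid over [-2, 2]: code_x varies along columns, code_y along rows.
        code_x = torch.linspace(-2, 2, steps=num_rows).view(1, num_rows).repeat(num_rows, 1)
        code_y = code_x.t()
        # Draw z from a standard normal and flatten the grid into (num_rows**2, 2) pairs.
        if self.args.ngpus > 0:
            self.z = torch.cuda.FloatTensor(*z_dim).normal_()
            self.code = torch.stack([code_x, code_y], dim=2).view(-1, 2).cuda()
        else:
            self.z = torch.FloatTensor(*z_dim).normal_()
            self.code = torch.stack([code_x, code_y], dim=2).view(-1, 2)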
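
The ordering of the (x, y) pairs stored in self.code can be hard to read off the linspace / repeat / t / stack / view chain. The following is a minimal, dependency-free sketch of the same layout in plain Python; the helper names linspace and make_code_grid are hypothetical and not part of DisentangleVAE, and the sketch deliberately avoids torch so it can be run anywhere.

```python
def linspace(start, stop, steps):
    """Evenly spaced values from start to stop inclusive (hypothetical helper)."""
    if steps == 1:
        return [float(start)]
    step = (stop - start) / (steps - 1)
    return [start + i * step for i in range(steps)]


def make_code_grid(num_rows, low=-2.0, high=2.0):
    """Return num_rows * num_rows (x, y) pairs in row-major order.

    Mirrors stack([code_x, code_y], dim=2).view(-1, 2) from the constructor:
    code_x[i][j] is the j-th grid value and code_y[i][j] is the i-th one,
    so x changes fastest and y changes once per row.
    """
    values = linspace(low, high, num_rows)
    return [(values[j], values[i])
            for i in range(num_rows)
            for j in range(num_rows)]


grid = make_code_grid(3)
for pair in grid:
    print(pair)
```

With num_rows = 3 the first three pairs share y = -2.0 while x runs over -2.0, 0.0, 2.0. Each flattened row of the code tensor therefore corresponds to one cell of the image grid, read left to right and then top to bottom, assuming the visualizer tiles its outputs in that same order.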