VAEGAN.py source code

python

Project: DisentangleVAE  Author: Jueast

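import torch
from torch.autograd import grad


def GAN_loss(self, x):
    # WGAN-GP critic objective: E[D(x)] - E[D(G(z))] - 10 * gradient penalty
    x = x.view(x.size(0), -1)  # flatten each sample into a vector
    # prior noise z ~ N(0, I), created on the same device and dtype as x
    eps = torch.randn(x.size(0), self.nz, device=x.device, dtype=x.dtype)
    # decode prior samples into fake data, flattened to match x
    recon_pz = self.decode(eps).view(x.size(0), -1)
    # one interpolation weight per sample, broadcast across the features
    alpha = torch.rand(x.size(0), 1, device=x.device, dtype=x.dtype)
    interpolates = alpha * x.detach() + (1 - alpha) * recon_pz.detach()
    interpolates.requires_grad_(True)
    D_interpolates = self.D(interpolates)
    # D returns one score per sample, so grad() needs explicit grad_outputs
    gradients = grad(D_interpolates, interpolates,
                     grad_outputs=torch.ones_like(D_interpolates),
                     create_graph=True)[0]
    slopes = gradients.norm(2, dim=1)  # per-sample L2 norm of the gradient
    # mean of the squared deviations, not the square of the mean deviation
    gradient_penalty = torch.mean((slopes - 1.) ** 2)
    return self.D(x).mean() - self.D(recon_pz).mean() - 10 * gradient_penalty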
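
GAN_loss computes the WGAN-GP critic objective E[D(x)] - E[D(x~)] - λ·E[(‖∇D(x^)‖₂ - 1)²], where x~ are decoded prior samples, x^ = αx + (1 - α)x~ with α ~ U(0, 1) per sample, and λ = 10. The standalone sketch below isolates that computation so it can be checked without the rest of the VAE-GAN. It assumes the critic maps a batch of flattened inputs to one score per sample; `gradient_penalty` and `critic_objective` are hypothetical helper names, not functions from the DisentangleVAE project.

```python
import torch
from torch import nn
from torch.autograd import grad


def gradient_penalty(critic, real, fake):
    """Hypothetical helper: mean over the batch of (||grad D(x_hat)||_2 - 1)^2."""
    real = real.view(real.size(0), -1)
    fake = fake.view(fake.size(0), -1)
    # one interpolation weight per sample, broadcast over the feature dimension
    alpha = torch.rand(real.size(0), 1, device=real.device, dtype=real.dtype)
    x_hat = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    d_hat = critic(x_hat)
    # gradient of each sample's score w.r.t. its own interpolated input
    grads = grad(d_hat, x_hat, grad_outputs=torch.ones_like(d_hat),
                 create_graph=True)[0]
    slopes = grads.norm(2, dim=1)
    return torch.mean((slopes - 1.0) ** 2)


def critic_objective(critic, real, fake, lam=10.0):
    """Hypothetical helper: the scalar the critic maximizes."""
    return (critic(real).mean() - critic(fake).mean()
            - lam * gradient_penalty(critic, real, fake))


if __name__ == "__main__":
    torch.manual_seed(0)
    # A linear critic D(x) = w.x has gradient w everywhere, so GP = (||w|| - 1)^2.
    critic = nn.Linear(4, 1, bias=False)
    with torch.no_grad():
        critic.weight.copy_(torch.tensor([[2.0, 0.0, 0.0, 0.0]]))  # ||w|| = 2
    real, fake = torch.randn(8, 4), torch.randn(8, 4)
    print(gradient_penalty(critic, real, fake).item())  # (2 - 1)^2 = 1.0
```

A linear critic makes the check deterministic: its input gradient is the same weight vector at every interpolated point, so the penalty is exactly (‖w‖ - 1)², whatever values α takes. With a nonlinear critic the per-sample slopes differ, and averaging the squared deviations (rather than squaring their average) penalizes each sample's departure from unit gradient norm instead of letting deviations of opposite sign cancel.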