def GAN_loss(self, x, gp_weight=10.0):
    """Return the gradient-penalised difference ``D(x) - D(decode(eps))``.

    ``x`` is flattened to ``(batch, -1)``. Latent noise ``eps`` of shape
    ``(batch, self.nz)`` is drawn from a standard normal and passed through
    ``self.decode``. A gradient penalty is then evaluated on random convex
    combinations of ``x`` and the decoded samples, and subtracted (scaled by
    ``gp_weight``) from the score difference.

    Args:
        x: Input batch; the first dimension is the batch size.
        gp_weight: Multiplier applied to the gradient penalty. Defaults to
            10, the value that was previously hard-coded.

    Returns:
        ``self.D(x) - self.D(recon_pz) - gp_weight * gradient_penalty``.
        The shape follows whatever ``self.D`` returns; the penalty itself
        is a scalar.

    NOTE(review): assumes ``self.decode`` yields a tensor with the same
    shape as the flattened ``x`` -- confirm against the decoder.
    """
    x = x.view(x.size(0), -1)
    batch_size = x.size(0)

    # Allocate the random tensors from ``x.data`` so that they share its
    # type and device. The interpolation coefficients used to be created on
    # the CPU unconditionally, which fails when ``x`` lives on the GPU.
    eps = x.data.new(batch_size, self.nz).normal_()
    alpha = x.data.new(batch_size, 1).uniform_(0, 1)
    alpha = alpha.expand(batch_size, x.size(1))

    recon_pz = self.decode(Variable(eps))

    # Interpolate on raw data so the penalty graph starts at
    # ``interpolates`` and does not reach back into the decoder.
    interpolates = alpha * x.data + (1 - alpha) * recon_pz.data
    interpolates = Variable(interpolates, requires_grad=True)
    D_interpolates = self.D(interpolates)

    # ``grad`` can only create its seed gradient implicitly for scalar
    # outputs; an explicit all-ones seed also covers per-sample outputs.
    gradients = grad(
        outputs=D_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(D_interpolates),
        create_graph=True,  # the penalty itself must be differentiable
    )[0]

    slopes = torch.sum(gradients ** 2, 1).sqrt()
    # Mean of the squared deviations, not the square of the mean deviation:
    # the latter lets slopes above and below 1 cancel each other out.
    gradient_penalty = torch.mean((slopes - 1.) ** 2)

    return self.D(x) - self.D(recon_pz) - gp_weight * gradient_penalty
# Comment list
# Table of contents