# Module-level imports assumed by this method:
import torch
from torch.nn import functional

def compute_loss_and_gradient(self, x):
    self.optimizer.zero_grad()
    recon_x, z_mean, z_var = self.model_eval(x)
    # Reconstruction term, averaged over all batch_size * 784 pixels
    # (binary_cross_entropy defaults to reduction='mean')
    binary_cross_entropy = functional.binary_cross_entropy(recon_x, x.view(-1, 784))
    # Analytical KL divergence D_KL(q(z|x) || p(z)) between the Gaussian
    # posterior and the standard-normal prior; see Appendix B of the VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # (https://arxiv.org/abs/1312.6114)
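    # Closed form (note z_var holds the variance sigma^2, not log-variance):
    #   D_KL = -0.5 * sum_j (1 + log(sigma_j^2) - mu_j^2 - sigma_j^2)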
    kl_div = -0.5 * torch.sum(1 + z_var.log() - z_mean.pow(2) - z_var)
    # Scale the KL term by batch_size * 784 so it matches the per-pixel
    # averaging of the BCE term above
    kl_div /= self.args.batch_size * 784
    loss = binary_cross_entropy + kl_div
    if self.mode == TRAIN:
        loss.backward()
        self.optimizer.step()
    # .data[0] was removed in PyTorch 0.4+; .item() returns the Python scalar
    return loss.item()
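For context, here is a minimal sketch of how this method might be driven from a training loop. It is an assumption about the surrounding code: `trainer` stands in for an instance of the class this method belongs to (with `mode`, `model_eval`, `optimizer`, and `args` already set up), `train_loader` is a hypothetical DataLoader over MNIST-style images, and `TRAIN` is the mode constant the method checks:

# Hypothetical driver loop for the method above; names here are assumptions,
# not part of the original code.
def train_one_epoch(trainer, train_loader):
    trainer.mode = TRAIN  # enable the backward pass and optimizer step
    total_loss = 0.0
    for x, _ in train_loader:  # labels are unused by the VAE
        total_loss += trainer.compute_loss_and_gradient(x)
    # Average per-batch loss over the epoch
    return total_loss / len(train_loader)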