def compute_loss(x_dec, x_next_pred_dec, x, x_next,
                 Qz, Qz_next_pred,
                 Qz_next):
    """Compute the bound loss and the next-state KL term for one batch.

    Args:
        x_dec: Decoder output for the current observation; compared against
            ``x``.
        x_next_pred_dec: Decoder output for the predicted next state; compared
            against ``x_next``.
        x: Current observations (reconstruction target for ``x_dec``).
        x_next: Next observations (reconstruction target for
            ``x_next_pred_dec``).
        Qz: Distribution object exposing ``mu`` and ``logsigma`` tensors.
            ``logsigma`` is treated as the log standard deviation: the
            log-variance used below is ``2 * logsigma``.
        Qz_next_pred: First argument passed to ``KLDGaussian``.
        Qz_next: Second argument passed to ``KLDGaussian``.

    Returns:
        Tuple ``(bound_loss, kl)``:

        - ``bound_loss`` is the mean of (both reconstruction terms plus the
          KL of ``Qz`` against a standard normal).
        - ``kl`` is the mean of ``KLDGaussian(Qz_next_pred, Qz_next)``.

        Per-sample terms are reduced with ``.sum(dim=1)``. This presumes
        inputs shaped ``(batch, features)``; TODO confirm against callers.
    """
    # Reconstruction terms, summed over dim 1.
    # NOTE(review): the leading minus assumes ``binary_crossentropy`` (defined
    # elsewhere) returns a log-likelihood, so negating it yields a positive
    # loss. Confirm against its definition.
    x_reconst_loss = -binary_crossentropy(x, x_dec).sum(dim=1)
    x_next_reconst_loss = -binary_crossentropy(
        x_next, x_next_pred_dec).sum(dim=1)

    # Closed-form KL( N(mu, sigma^2) || N(0, I) ) per sample:
    #   -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    mu = Qz.mu
    logvar = Qz.logsigma.mul(2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

    # Bound loss to minimise: reconstruction terms plus the prior KL.
    bound_loss = x_reconst_loss + x_next_reconst_loss + kld
    # Divergence between the two next-state distributions (helper defined
    # elsewhere; its argument order is kept as in the original).
    kl = KLDGaussian(Qz_next_pred, Qz_next)
    return bound_loss.mean(), kl.mean()
# Scraped web-page labels (not code): "评论列表" = "Comment list",
# "文章目录" = "Table of contents".