def loss(self, x):
    """Return the training loss: reconstruction BCE plus the summed latent KL.

    Calls ``self.forward(x)`` first and then reads the state it leaves on
    ``self``:
    - ``self.cs[-1]``: the final output, mapped through ``self.sigmoid``.
    - ``self.mus[t]``, ``self.sigmas[t]``, ``self.logsigmas[t]``: one entry
      for each of the ``self.T`` steps.

    Args:
        x: target batch compared element-wise with the reconstruction by
            ``nn.BCELoss``, so its values must lie in [0, 1]. Presumably shaped
            (batch, A*B) — TODO confirm against ``forward``.

    Returns:
        Scalar tensor ``latent_loss + recon_loss``.
    """
    self.forward(x)

    # nn.BCELoss averages over every element. Multiplying by A*B rescales that
    # to a per-sample sum, assuming each sample has A*B elements.
    recon = self.sigmoid(self.cs[-1])
    recon_loss = nn.BCELoss()(recon, x) * self.A * self.B

    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per latent unit is
    # 0.5 * (mu^2 + sigma^2 - 2*log(sigma) - 1). The "-1" applies to every
    # latent unit at every step, so it is subtracted inside the sum.
    # Dim 1 is summed (assumed latent axis); dim 0 is averaged below
    # (assumed batch axis) — TODO confirm shapes.
    kl = 0
    for t in range(self.T):
        mu = self.mus[t]
        sigma = self.sigmas[t]
        logsigma = self.logsigmas[t]
        kl = kl + 0.5 * torch.sum(mu * mu + sigma * sigma - 2 * logsigma - 1, 1)
    latent_loss = torch.mean(kl)

    return latent_loss + recon_loss
# correct
# Scraped web-page residue (not code): "Comment list" / "Table of contents".