def get_loss_func(self, C=1.0, k=1):
    """Return a closure that computes the VAE training loss for a batch.

    The returned loss is the per-sample reconstruction term plus ``C`` times
    the per-sample KL-divergence term. In the standard VAE formulation this
    is the negated ELBO averaged over the batch.

    Args:
        C (float): Weight of the KL-divergence term. ``1.0`` gives the plain
            ELBO; other values rescale the regulariser.
        k (int): Number of Monte Carlo samples of ``z`` drawn per input to
            estimate the reconstruction term. Must be at least 1.

    Returns:
        callable: ``lf(x)``, which returns the loss Variable. It also stores
        the loss on ``self.loss`` and the reconstruction part on
        ``self.rec_loss`` for reporting.

    Raises:
        ValueError: If ``k`` is less than 1. Otherwise the Monte Carlo
            average would divide by zero.
    """
    if k < 1:
        raise ValueError("k must be >= 1, got {}".format(k))

    def lf(x):
        mu, ln_var = self.encode(x)
        # The leading axis of the encoder output is taken as the batch axis.
        batchsize = len(mu.data)

        # Monte Carlo estimate of the reconstruction term. bernoulli_nll sums
        # over the batch by default, so divide by both the sample count and
        # the batch size to get a per-sample mean.
        rec_loss = 0
        for _ in six.moves.range(k):
            z = F.gaussian(mu, ln_var)
            # bernoulli_nll applies the sigmoid itself; decode is presumably
            # returning raw logits when sigmoid=False -- confirm in decode().
            rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False))
        rec_loss /= (k * batchsize)
        self.rec_loss = rec_loss

        # gaussian_kl_divergence also sums over the batch, so normalise only
        # the KL term here. rec_loss is already per-sample; dividing the total
        # again would shrink the reconstruction term by batchsize a second time.
        kl_loss = gaussian_kl_divergence(mu, ln_var) / batchsize
        self.loss = self.rec_loss + C * kl_loss
        return self.loss
    return lf
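# 评论列表 ("Comment list") -- web-page navigation residue from the scraped source, not code.
# 文章目录 ("Table of contents") -- web-page navigation residue from the scraped source, not code.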