def train(self, x, L=1, test=False):
    """Run one VAE training step on the minibatch ``x``.

    Minimizes the negative evidence lower bound (ELBO):
    reconstruction term  -E_q(z|x)[log p(x|z)]  estimated with ``L``
    Monte Carlo samples, plus the KL divergence  KL(q(z|x) || p(z)).
    Performs backprop and a parameter update, then returns the scalar
    loss value (moved to CPU first when running on GPU).

    Args:
        x: input minibatch variable; first axis is the batch dimension.
        L (int): number of Monte Carlo samples of z per example.
        test (bool): forwarded to encoder/decoder (e.g. to disable
            dropout/batch-norm updates at eval time).

    Returns:
        The averaged loss for the minibatch as raw array data.
    """
    batchsize = x.data.shape[0]
    # Encode x into the parameters (mean, log-variance) of q(z|x).
    # NOTE(review): apply_f=False presumably returns the raw distribution
    # parameters instead of a sampled/activated output — confirm against
    # the encoder implementation.
    z_mean, z_ln_var = self.encoder(x, test=test, apply_f=False)

    loss = 0
    # `range` instead of Python-2-only `xrange`: equivalent here (L is
    # small) and keeps the code runnable on Python 3.
    for l in range(L):
        # Sample z ~ q(z|x) via the reparameterization trick.
        z = F.gaussian(z_mean, z_ln_var)
        # Decode the sample back to data space.
        x_expectation = self.decoder(z, test=test, apply_f=False)
        # Accumulate the per-example Bernoulli negative log-likelihood,
        # i.e. the Monte Carlo estimate of -E_q(z|x)[log p(x|z)].
        loss += self.bernoulli_nll_keepbatch(x, x_expectation)
    if L > 1:
        # Average the reconstruction term over the L samples.
        loss /= L

    # Add the (analytic) KL divergence term to complete the negative ELBO.
    loss += self.gaussian_kl_divergence_keepbatch(z_mean, z_ln_var)
    # Reduce per-example losses to a scalar mean over the batch.
    loss = F.sum(loss) / batchsize

    # Backprop and parameter update.
    self.zero_grads()
    loss.backward()
    self.update()

    if self.gpu:
        # Bring the loss back to host memory so .data is a CPU array.
        loss.to_cpu()
    return loss.data
# (removed: stray blog-page navigation text "评论列表" / "文章目录"
#  — web-scrape residue, not part of the source code)