def __call__(self, x):
    """Compute the scalar training loss for the batch ``x``.

    ``x`` is encoded once, then ``self.num_zsamples`` latent samples are
    drawn and decoded, and the resulting log-likelihoods are averaged.
    The KL term is weighted by a temperature that is capped at 1.0; the
    stored temperature is advanced by its increment on every call, so
    calling this method changes the weighting used by the next call.

    Side effects: sets ``self.kl``, ``self.logp``, ``self.obj_batch``,
    ``self.obj`` and ``self.timing_info`` (a NumPy array holding
    ``[encoding seconds, mean decoding seconds per sample]``), and
    mutates ``self.temperature['value']``.

    Args:
        x: Batch of observations, passed to ``self.encode`` and to
            ``gaussian_logp``.

    Returns:
        ``self.obj``: the negated sum of ``self.obj_batch`` divided by
        ``self.obj_batch.shape[0]``.

    Raises:
        ZeroDivisionError: if ``self.num_zsamples`` is 0 (the averages
            divide by it).
    """
    # Compute q(z|x).
    # NOTE(review): encode() is presumably what populates self.qmu and
    # self.qln_var, which are read below -- confirm against encode().
    encode_start = time.time()
    self.encode(x)
    encoding_time = float(time.time() - encode_start)

    self.kl = gaussian_kl_divergence_standard(self.qmu, self.qln_var)

    self.logp = 0
    decoding_time_total = 0.
    # `range` rather than `xrange`: the latter does not exist on Python 3.
    for _ in range(self.num_zsamples):
        # z ~ q(z|x)
        z = F.gaussian(self.qmu, self.qln_var)

        # Compute p(x|z).
        # NOTE(review): decode() is presumably what populates self.pmu and
        # self.pln_var used in the likelihood below -- confirm.
        decode_start = time.time()
        self.decode(z)
        decoding_time_total += time.time() - decode_start

        # Accumulate the log-likelihood of x under this sample.
        self.logp += gaussian_logp(x, self.pmu, self.pln_var)

    # Read the (capped) temperature first, then advance the stored value,
    # so the increment only affects subsequent calls.
    current_temperature = min(self.temperature['value'], 1.0)
    self.temperature['value'] += self.temperature['increment']

    # Turn the accumulated sums into per-sample averages.
    decoding_time_average = decoding_time_total / self.num_zsamples
    self.logp /= self.num_zsamples

    self.obj_batch = self.logp - (current_temperature * self.kl)
    self.timing_info = np.array([encoding_time, decoding_time_average])

    # Negate so that minimising the returned value maximises the objective;
    # dividing by the leading dimension gives a per-example mean.
    batch_size = self.obj_batch.shape[0]
    self.obj = -F.sum(self.obj_batch) / batch_size
    return self.obj
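# (web-page navigation residue, not part of the code) "Comment list"
# (web-page navigation residue, not part of the code) "Table of contents"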