def __call__(self, x):
    """Compute the importance-weighted (IWAE-style) loss for a batch ``x``.

    ``x`` is encoded once.  Then ``self.num_zsamples`` samples
    ``z_k ~ q(z|x)`` are drawn, each sample is decoded, and the per-sample
    log weights

        log w_k = log p(x|z_k) + T * (log p(z_k) - log q(z_k|x))

    are combined with a log-sum-exp over the sample axis.  ``T`` is the
    annealing temperature ``min(self.temperature['value'], 1.0)``.  It is
    read once per call so that every sample of the batch is weighted under
    the same objective.

    Attributes written on ``self``:
        importance_weights: reset to 0.
        w_holder: list of the per-sample log weights ``log w_k``.
        obj_batch: per-example log importance-weighted estimate
            (``logsumexp_k log w_k - log K``).
        obj: the returned loss, ``-sum(obj_batch) / batch_size``.
        kl, logp: sample averages of ``log q(z|x) - log p(z)`` and
            ``log p(x|z)``; recorded for reporting only, not used in ``obj``.
        timing_info: ``np.array([encode_seconds, mean_decode_seconds])``.
        temperature['value']: advanced by ``temperature['increment']`` once
            per call.

    Args:
        x: input batch, forwarded to ``self.encode`` and ``bernoulli_logp``.

    Returns:
        The scalar loss ``self.obj``.

    Raises:
        ValueError: if ``self.num_zsamples`` is less than 1 (the log-sum-exp
            over zero samples is undefined).
    """
    num_samples = self.num_zsamples
    if num_samples < 1:
        raise ValueError(
            "num_zsamples must be at least 1, got %r" % (num_samples,))

    # Obtain the parameters of q(z|x) (sets self.qmu / self.qln_var).
    encode_start = time.time()
    self.encode(x)
    encoding_time = time.time() - encode_start

    # Fix the (clamped) annealing temperature for the whole batch, so all
    # importance samples that are log-sum-exp'd together share one objective.
    temperature = min(self.temperature['value'], 1.0)

    decoding_time_total = 0.
    self.importance_weights = 0
    self.w_holder = []
    self.kl = 0
    self.logp = 0
    for _ in range(num_samples):
        # Sample z ~ q(z|x) and evaluate log q(z|x).
        z = F.gaussian(self.qmu, self.qln_var)
        encoder_log = gaussian_logp(z, self.qmu, self.qln_var)

        # Obtain the parameters of p(x|z) (sets self.p_ber_prob_logit).
        decode_start = time.time()
        self.decode(z)
        decoding_time_total += time.time() - decode_start

        decoder_log = bernoulli_logp(x, self.p_ber_prob_logit)  # log p(x|z)
        prior_log = gaussian_logp0(z)  # log p(z)

        # Store this sample's log weight.
        self.w_holder.append(
            decoder_log + temperature * (prior_log - encoder_log))

        # KL / log-likelihood equivalents: recorded for reporting only.
        self.kl += encoder_log - prior_log
        self.logp += decoder_log

    # Advance the annealing schedule once per forward pass.
    self.temperature['value'] += self.temperature['increment']

    # Stack the K sample log weights along axis 0 and average in log space:
    # log(1/K * sum_k w_k) = logsumexp_k(log w_k) - log K.
    log_weights = F.stack(self.w_holder)
    self.obj_batch = F.logsumexp(log_weights, axis=0) - np.log(num_samples)
    self.kl /= num_samples
    self.logp /= num_samples

    batch_size = self.obj_batch.shape[0]
    self.obj = -F.sum(self.obj_batch) / batch_size
    self.timing_info = np.array(
        [encoding_time, decoding_time_total / num_samples])
    return self.obj
评论列表
文章目录