model.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:dgm 作者: ashwindcruz 项目源码 文件源码
def __call__(self, x):
    """Compute the temperature-annealed importance-weighted loss for ``x``.

    Draws ``self.num_zsamples`` latent samples ``z ~ q(z|x)`` and, for each
    one, forms the log weight

        log w = log p(x|z) + T * (log p(z) - log q(z|x)),

    with ``T = min(self.temperature['value'], 1.0)``.  The per-example
    objective is ``logsumexp_k(log w_k) - log(num_zsamples)`` (the
    importance-weighted estimate when ``T == 1``).  The returned loss is
    ``-sum(obj_batch) / obj_batch.shape[0]``.

    Args:
        x: Input batch.  It is passed to ``self.encode`` and scored with
            ``bernoulli_logp``.  Presumably its first axis is the batch
            axis, since ``obj_batch.shape[0]`` is used as the batch size
            (TODO confirm).

    Returns:
        The scalar loss, which is also stored in ``self.obj``.

    Raises:
        ValueError: If ``self.num_zsamples`` is less than 1.

    Side effects:
        Sets ``self.importance_weights`` (to 0) and ``self.w_holder`` (the
        list of per-sample log weights).  Sets ``self.kl`` and
        ``self.logp``, the sample averages of ``log q(z|x) - log p(z)`` and
        ``log p(x|z)``; these are kept for reporting only.  Sets
        ``self.obj_batch``, ``self.obj`` and ``self.timing_info``
        (``[encode seconds, mean decode seconds]``).  Advances
        ``self.temperature['value']`` by ``self.temperature['increment']``
        on every call.
    """
    num_samples = self.num_zsamples
    if num_samples < 1:
        # Nothing could be stacked or averaged; fail early with a clear
        # message rather than deep inside F.stack on an empty list.
        raise ValueError(
            "num_zsamples must be >= 1, got {!r}".format(num_samples))

    # NOTE(review): time.time() measures host wall-clock time.  If
    # encode/decode launch asynchronous GPU work, these timings may exclude
    # device execution time -- confirm before relying on them.

    # Obtain parameters for q(z|x).  encode() is expected to populate
    # self.qmu / self.qln_var, which are read below.
    start = time.time()
    self.encode(x)
    encoding_time = time.time() - start

    decoding_time_total = 0.
    self.importance_weights = 0
    self.w_holder = []
    self.kl = 0
    self.logp = 0

    # The annealing temperature is the same for every sample of this batch.
    # It is capped at 1.0 so the (prior - encoder) term is never up-weighted.
    current_temperature = min(self.temperature['value'], 1.0)

    for _ in range(num_samples):
        # Sample z ~ q(z|x) and score it under the encoder: log q(z|x).
        z = F.gaussian(self.qmu, self.qln_var)
        encoder_log = gaussian_logp(z, self.qmu, self.qln_var)

        # Obtain parameters for p(x|z); read back via self.p_ber_prob_logit.
        start = time.time()
        self.decode(z)
        decoding_time_total += time.time() - start

        decoder_log = bernoulli_logp(x, self.p_ber_prob_logit)  # log p(x|z)
        prior_log = gaussian_logp0(z)  # log p(z)

        # Temperature-annealed log importance weight for this sample.
        self.w_holder.append(
            decoder_log + current_temperature * (prior_log - encoder_log))

        # Recorded for reporting only; not part of the objective.
        self.kl += encoder_log - prior_log
        self.logp += decoder_log

    self.temperature['value'] += self.temperature['increment']

    # log((1/K) * sum_k w_k), computed with logsumexp over the sample axis.
    logps = F.stack(self.w_holder)
    self.obj_batch = F.logsumexp(logps, axis=0) - np.log(num_samples)
    self.kl /= num_samples
    self.logp /= num_samples

    batch_size = self.obj_batch.shape[0]
    self.obj = -F.sum(self.obj_batch) / batch_size
    self.timing_info = np.array(
        [encoding_time, decoding_time_total / num_samples])

    return self.obj
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号