model.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:dgm 作者: ashwindcruz 项目源码 文件源码
def __call__(self, x):
        """Return the negative batch-mean (KL-annealed) ELBO for ``x``.

        Encodes ``x`` once, then for each of ``self.num_zsamples`` samples
        z ~ q(z|x) decodes z and accumulates log q(z|x) - log p(z) (the KL
        term) and log p(x|z) (the reconstruction term). Both sums are
        averaged over the samples and combined as ``logp - T * kl``. The KL
        weight is ``T = min(self.temperature['value'], 1.0)``. After reading
        ``T``, the stored temperature value is advanced by
        ``self.temperature['increment']``, so every call moves the annealing
        schedule forward.

        Side effects:
            self.kl, self.logp -- sample-averaged KL and reconstruction terms
            self.obj_batch     -- per-example objective ``logp - T * kl``
            self.timing_info   -- np.array([encode seconds,
                                            mean decode seconds per sample])
            self.obj           -- the value returned below

        Args:
            x: input batch. NOTE(review): the first axis of ``obj_batch`` is
                used as the batch size, so x is presumably batch-major;
                confirm against callers.

        Returns:
            ``-sum(self.obj_batch) / batch_size``, i.e. a loss to minimise.

        Raises:
            ValueError: if ``self.num_zsamples`` is less than 1. The sample
                averages would otherwise divide by zero.
        """
        num_samples = self.num_zsamples
        if num_samples < 1:
            raise ValueError(
                "num_zsamples must be >= 1, got %r" % (num_samples,))

        # Compute q(z|x). encode() is expected to set self.qmu and
        # self.qln_var, which are read in the sampling loop below.
        encode_start = time.time()
        self.encode(x)
        encoding_time = time.time() - encode_start

        decoding_time_total = 0.
        self.kl = 0
        self.logp = 0
        # range (not xrange) so this runs under both Python 2 and Python 3.
        for _ in range(num_samples):
            # z ~ q(z|x). F is presumably chainer.functions, so this is a
            # reparameterised sample from N(qmu, exp(qln_var)).
            z = F.gaussian(self.qmu, self.qln_var)
            encoder_log = gaussian_logp(z, self.qmu, self.qln_var)  # log q(z|x)

            # Compute p(x|z). decode() is expected to set
            # self.p_ber_prob_logit, which is read just below.
            decode_start = time.time()
            self.decode(z)
            decoding_time_total += time.time() - decode_start

            prior_log = gaussian_logp0(z)  # log p(z)

            self.kl += encoder_log - prior_log
            self.logp += bernoulli_logp(x, self.p_ber_prob_logit)

        # KL annealing: the weight is capped at 1.0. The stored value keeps
        # growing by `increment` on every call.
        current_temperature = min(self.temperature['value'], 1.0)
        self.temperature['value'] += self.temperature['increment']

        self.logp /= num_samples
        self.kl /= num_samples
        self.obj_batch = self.logp - current_temperature * self.kl
        self.timing_info = np.array(
            [encoding_time, decoding_time_total / num_samples])

        batch_size = self.obj_batch.shape[0]
        self.obj = -F.sum(self.obj_batch) / batch_size
        return self.obj
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号