model.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:dgm 作者: ashwindcruz 项目源码 文件源码
def __call__(self, x):
    """Compute the batch-averaged negative annealed objective for input ``x``.

    Encodes ``x`` once, then for each of ``self.num_zsamples`` draws:
    samples ``z_0`` from q(z|x) with ``F.gaussian(qmu, qln_var)``, applies
    the Householder flow (``self.house_transform``), decodes ``z_T`` to
    Bernoulli logits, and accumulates ``bernoulli_logp`` and
    ``gaussian_kl_divergence``. Both sums are averaged over the draws and
    combined as ``logp - T * kl`` with ``T = min(temperature['value'], 1.0)``.

    Side effects on ``self``:
        logp, kl      -- per-sample-averaged reconstruction term and KL term.
        obj_batch     -- ``logp - T * kl`` (indexed by batch on axis 0).
        obj           -- the returned scalar.
        timing_info   -- ``np.array([encoding seconds,
                         mean per-sample decoding seconds])``.
        temperature['value'] -- incremented by ``temperature['increment']``
                         on every call (KL annealing schedule).

    Args:
        x: minibatch passed to ``self.encode`` and ``bernoulli_logp``
           (presumably binary data of shape (batch, dim) -- TODO confirm).

    Returns:
        ``-sum(obj_batch) / batch_size``, where ``batch_size`` is
        ``obj_batch.shape[0]``.
    """
    # Obtain parameters for q(z|x) and time the encoder pass.
    # The third output of encode() (qh_vec_0) is not used in this method.
    encoding_start = time.time()
    qmu, qln_var, _qh_vec_0 = self.encode(x)
    encoding_time = float(time.time() - encoding_start)

    decoding_time_total = 0.
    self.kl = 0
    self.logp = 0
    # range instead of xrange keeps this working on both Python 2 and 3;
    # the number of z samples is small, so Python 2's list is harmless.
    for _ in range(self.num_zsamples):
        # z_0 ~ q(z|x)
        z_0 = F.gaussian(qmu, qln_var)

        # Perform Householder flow transformation, Equation (8), then
        # obtain parameters for p(x|z_T). Both steps count as decoding time.
        decoding_start = time.time()
        z_T = self.house_transform(z_0)
        p_ber_prob_logit = self.decode(z_T)
        decoding_time_total += time.time() - decoding_start

        # Compute objective with the logits decode() just returned for this
        # sample, not an attribute on self that may be missing or stale.
        self.logp += bernoulli_logp(x, p_ber_prob_logit)
        self.kl += gaussian_kl_divergence(z_0, qmu, qln_var, z_T)

    decoding_time_average = decoding_time_total / self.num_zsamples

    # Monte Carlo average over the z samples.
    self.logp /= self.num_zsamples
    self.kl /= self.num_zsamples

    # KL weight is capped at 1.0; the stored value keeps growing by
    # 'increment' after each call.
    current_temperature = min(self.temperature['value'], 1.0)
    self.obj_batch = self.logp - (current_temperature * self.kl)
    self.temperature['value'] += self.temperature['increment']

    self.timing_info = np.array([encoding_time, decoding_time_average])

    batch_size = self.obj_batch.shape[0]

    # Negate so that minimizing self.obj maximizes the objective.
    self.obj = -F.sum(self.obj_batch) / batch_size

    return self.obj
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号