networks.py source code

python

Project: comprehend    Author: Fenugreek
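import tensorflow as tf


def recode_cost(self, inputs, variation, eps=1e-5, **kwargs):
    """
    Cost for a given input batch of samples, under the current params.

    `variation` is presumably a batch of standard-normal noise with the
    latent dimensionality, used for the reparameterization trick.
    """
    h = self.get_h_inputs(inputs)
    z_mu = tf.matmul(h, self.params['Mhz']) + self.params['bMhz']
    # z_sig is a log-variance: tf.exp(z_sig) is the latent variance
    z_sig = tf.matmul(h, self.params['Shz']) + self.params['bShz']

    # KL divergence between latent space induced by encoder and ...
    # (this sum is 2 * KL(N(z_mu, exp(z_sig)) || N(0, I)); it is halved in the return)
    lat_loss = -tf.reduce_sum(1 + z_sig - z_mu**2 - tf.exp(z_sig), 1)

    # Reparameterization trick: z = mu + sigma * noise, sigma = exp(z_sig / 2)
    z = z_mu + tf.exp(0.5 * z_sig) * variation
    h = self.get_h_latents(z)
    x_mu = self.decoding(tf.matmul(h, self.params['Mhx']) + self.params['bMhx'])
    # x_sig is used as a per-dimension variance in the likelihood below
    x_sig = self.decoding(tf.matmul(h, self.params['Shx']) + self.params['bShx'])
    # x_sig = tf.clip_by_value(x_mu * (1 - x_mu), .05, 1)

    # decoding likelihood term: 2 * Gaussian negative log-likelihood, up to a constant.
    # eps guards both the log and the division against a vanishing variance.
    like_loss = tf.reduce_sum(tf.math.log(x_sig + eps) +
                              (inputs - x_mu)**2 / (x_sig + eps), 1)

    # # Mean cross entropy between input and encode-decoded input.
    # like_loss = 2 * tf.reduce_sum(functions.cross_entropy(inputs, x_mu), 1)

    return .5 * tf.reduce_mean(like_loss + lat_loss)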
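The two terms of `recode_cost` can be checked outside TensorFlow. Below is a minimal NumPy sketch of the same cost. It assumes that `get_h_inputs` and `get_h_latents` are identity maps and that `decoding` is a sigmoid. These are hypothetical stand-ins, since the surrounding class is not shown. It also reads `z_sig` as a log-variance and `x_sig` as a variance, which is what the formulas imply. `recode_cost_np` and `kl_to_standard_normal` are illustrative names, not part of the comprehend project.

```python
import numpy as np


def sigmoid(a):
    # Hypothetical stand-in for self.decoding
    return 1.0 / (1.0 + np.exp(-a))


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), per dimension."""
    return 0.5 * (np.exp(logvar) + mu**2 - 1 - logvar)


def recode_cost_np(inputs, variation, params, eps=1e-5):
    """NumPy restatement of recode_cost with identity hidden layers (assumption)."""
    h = inputs  # stand-in for self.get_h_inputs(inputs)
    z_mu = h @ params['Mhz'] + params['bMhz']
    z_logvar = h @ params['Shz'] + params['bShz']  # 'z_sig' in the original
    # Sum over latent dims of 2 * KL(q(z|x) || N(0, I))
    lat_loss = -np.sum(1 + z_logvar - z_mu**2 - np.exp(z_logvar), axis=1)
    # Reparameterization trick: z = mu + sigma * noise
    z = z_mu + np.exp(0.5 * z_logvar) * variation
    h = z  # stand-in for self.get_h_latents(z)
    x_mu = sigmoid(h @ params['Mhx'] + params['bMhx'])
    x_var = sigmoid(h @ params['Shx'] + params['bShx'])  # 'x_sig' in the original
    # 2 * Gaussian negative log-likelihood, dropping the log(2*pi) constant
    like_loss = np.sum(np.log(x_var + eps) + (inputs - x_mu)**2 / (x_var + eps), axis=1)
    return 0.5 * np.mean(like_loss + lat_loss)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    D, L, N = 4, 2, 8  # hypothetical input dim, latent dim, batch size
    shapes = {'Mhz': (D, L), 'bMhz': (L,), 'Shz': (D, L), 'bShz': (L,),
              'Mhx': (L, D), 'bMhx': (D,), 'Shx': (L, D), 'bShx': (D,)}
    params = {k: 0.1 * rng.standard_normal(s) for k, s in shapes.items()}
    inputs = rng.uniform(size=(N, D))
    variation = rng.standard_normal((N, L))
    print(recode_cost_np(inputs, variation, params))
```

The `.5` in the return is what turns `lat_loss` into an actual KL divergence and `like_loss` into a negative log-likelihood. Without it, both sums are twice those quantities.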