vae.py source code

python

Project: vae-npvc Author: JeremyCCHsu
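def loss(self, x, y):
    with tf.name_scope('loss'):
        # Encode x into the parameters of q(z|x) (z_lv is presumably the
        # log-variance), draw a latent sample, then decode it given y.
        z_mu, z_lv = self._encode(x)
        z = GaussianSampleLayer(z_mu, z_lv)
        xh = self._generate(z, y)

        # KL term: q(z|x) against a prior whose mean and variance
        # parameter are both zero, averaged over the batch.
        D_KL = tf.reduce_mean(
            GaussianKLD(
                slim.flatten(z_mu),
                slim.flatten(z_lv),
                slim.flatten(tf.zeros_like(z_mu)),
                slim.flatten(tf.zeros_like(z_lv)),
            )
        )
        # Reconstruction term: Gaussian log-density of x centred on xh,
        # with the variance parameter fixed at zero.
        logPx = tf.reduce_mean(
            GaussianLogDensity(
                slim.flatten(x),
                slim.flatten(xh),
                tf.zeros_like(slim.flatten(xh)),
            )
        )

    loss = dict()
    loss['G'] = -logPx + D_KL  # negative ELBO, the quantity to minimize
    loss['D_KL'] = D_KL
    loss['logP'] = logPx

    # TensorBoard summaries for the two loss terms and the data.
    tf.summary.scalar('KL-div', D_KL)
    tf.summary.scalar('logPx', logPx)

    tf.summary.histogram('xh', xh)
    tf.summary.histogram('x', x)
    return loss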
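
The helpers `GaussianKLD` and `GaussianLogDensity` are not shown in this snippet. As a rough illustration of what the two loss terms measure, the sketch below implements the standard closed-form expressions for diagonal Gaussians in plain Python. It assumes the variance arguments are log-variances and that each helper sums over feature dimensions; the function names `gaussian_kld` and `gaussian_log_density` are hypothetical and are not the project's API.

```python
import math


def gaussian_kld(mu1, lv1, mu2, lv2):
    """KL(N(mu1, exp(lv1)) || N(mu2, exp(lv2))) for diagonal Gaussians.

    Assumption: lv1 and lv2 are log-variances. The per-dimension terms
    are summed into a single scalar.
    """
    total = 0.0
    for m1, l1, m2, l2 in zip(mu1, lv1, mu2, lv2):
        total += 0.5 * (
            (l2 - l1) + (math.exp(l1) + (m1 - m2) ** 2) / math.exp(l2) - 1.0
        )
    return total


def gaussian_log_density(x, mu, lv):
    """log N(x; mu, exp(lv)) for a diagonal Gaussian, summed over dimensions.

    Assumption: lv is a log-variance, so lv = 0 means unit variance.
    """
    total = 0.0
    for xi, m, l in zip(x, mu, lv):
        total += -0.5 * (math.log(2.0 * math.pi) + l + (xi - m) ** 2 / math.exp(l))
    return total


# One-sample analogue of loss['G'] = -logPx + D_KL from the method above.
z_mu, z_lv = [0.3, -0.1], [-0.5, 0.2]   # made-up encoder outputs
x, xh = [1.0, 2.0], [0.8, 2.5]          # made-up input and reconstruction
d_kl = gaussian_kld(z_mu, z_lv, [0.0, 0.0], [0.0, 0.0])
log_px = gaussian_log_density(x, xh, [0.0, 0.0])
loss_g = -log_px + d_kl
```

Under these assumptions, the zero tensors passed in the original method correspond to a standard-normal prior in the KL term and a unit-variance likelihood in the reconstruction term, so `-logPx` reduces to half the squared error plus a constant.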