lifelong_vae.py source file

python
Project: LifelongVAE  Author: jramapuram
def reparameterize(encoded, num_discrete, tau, hard=False,
                       rnd_sample=None, eps=1e-20):
        eshp = encoded.get_shape().as_list()
        print("encoded = ", eshp)
        num_normal = eshp[1] - num_discrete
        print("num_normal = ", num_normal)
        logits_normal = encoded[:, 0:num_normal]
        logits_gumbel = encoded[:, num_normal:eshp[1]]

        # we reparameterize using both the N(0, I) and the gumbel(0, 1)
        z_discrete, kl_discrete = gumbel_reparmeterization(logits_gumbel,
                                                           tau,
                                                           rnd_sample,
                                                           hard)
        z_n, kl_n = gaussian_reparmeterization(logits_normal)

        # merge and pad appropriately
        z = tf.concat([z_n, z_discrete], axis=1)

        return [slim.flatten(z),
                slim.flatten(z_n),
                slim.flatten(z_discrete),
                slim.flatten(tf.nn.softmax(logits_gumbel)),
                kl_n,
                kl_discrete]
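
The helpers gumbel_reparmeterization and gaussian_reparmeterization are not shown in this excerpt. As a rough illustration of what the two sampling steps usually involve, here is a framework-free sketch that uses only the Python standard library and works on a single row instead of a batched tensor. It is not the author's implementation: the names sample_gaussian, sample_gumbel_softmax and reparameterize_sketch are hypothetical, the layout of the Gaussian slice (means followed by log-variances) is an assumption, and the KL terms and the hard option are omitted.

```python
import math
import random


def sample_gaussian(params, rng):
    """Reparameterized N(mu, sigma^2) sample.

    Assumption: `params` is [mu_1..mu_k, logvar_1..logvar_k].
    """
    k = len(params) // 2
    mu, logvar = params[:k], params[k:]
    # z = mu + sigma * eps, with eps ~ N(0, 1)
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, logvar)]


def sample_gumbel_softmax(logits, tau, rng, eps=1e-20):
    """Relaxed one-hot sample: softmax((logits + g) / tau), g ~ Gumbel(0, 1)."""
    # Gumbel(0, 1) noise via inverse transform: g = -log(-log(u)), u ~ U(0, 1)
    noisy = [(l - math.log(-math.log(rng.random() + eps) + eps)) / tau
             for l in logits]
    top = max(noisy)  # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


def reparameterize_sketch(encoded, num_discrete, tau, rng):
    """Split one encoded row into Gaussian and discrete parts, sample both."""
    num_normal = len(encoded) - num_discrete
    z_n = sample_gaussian(encoded[:num_normal], rng)
    z_discrete = sample_gumbel_softmax(encoded[num_normal:], tau, rng)
    # Concatenate in the same order as the original: Gaussian first
    return z_n + z_discrete, z_n, z_discrete


rng = random.Random(0)
z, z_n, z_discrete = reparameterize_sketch(
    [0.1, -0.3, 0.0, 0.0, 2.0, 0.5, -1.0], num_discrete=3, tau=1.0, rng=rng)
```

A lower tau pushes z_discrete closer to a one-hot vector, while a higher tau makes it closer to uniform. In both cases the entries stay non-negative and sum to 1.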