decoder.py file source code

python

Project: adversarial-variational-bayes    Author: gdikov
def __init__(self, latent_dim, data_dim, network_architecture='synthetic'):
        """
        Args:
            latent_dim: int, the flattened dimensionality of the latent space 
            data_dim: int, the flattened dimensionality of the output space (data space)
            network_architecture: str, the architecture name for the body of the Decoder model
        """
        super(Decoder, self).__init__(latent_dim=latent_dim, data_dim=data_dim,
                                      network_architecture=network_architecture,
                                      name='Standard Decoder')

        generator_body = get_network_by_name['decoder'][network_architecture](self.latent_input)

        # NOTE: all decoder layers have names prefixed by `dec`.
        # This is essential for the partial model freezing during training.
        sampler_params = Dense(self.data_dim, activation='sigmoid', name='dec_sampler_params')(generator_body)

        # probability clipping is necessary because the Bernoulli `log_prob` produces NaNs when the probabilities hit exactly 0 or 1.
        sampler_params = Lambda(lambda x: 1e-6 + (1 - 2e-6) * x, name='dec_probs_clipper')(sampler_params)

        def bernoulli_log_probs(args):
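            """Element-wise Bernoulli log-likelihood of the data `x` under the predicted probabilities `mu`."""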
            from tensorflow.contrib.distributions import Bernoulli
            mu, x = args
            log_px = Bernoulli(probs=mu, name='dec_bernoulli').log_prob(x)
            return log_px

        log_probs = Lambda(bernoulli_log_probs, name='dec_bernoulli_logprob')([sampler_params, self.data_input])

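        # Two views of the same graph: `generator` maps latent codes to the clipped Bernoulli
        # probabilities, while `ll_estimator` evaluates the log-likelihood of data given latent codes.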
        self.generator = Model(inputs=self.latent_input, outputs=sampler_params, name='dec_sampling')
        self.ll_estimator = Model(inputs=[self.data_input, self.latent_input], outputs=log_probs, name='dec_trainable')
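Note: `tensorflow.contrib.distributions` was removed in TensorFlow 2.x. As a minimal sketch (not part of the project source; it assumes the separate `tensorflow_probability` package as a stand-in), the same clipped Bernoulli log-likelihood could be written as:

import tensorflow_probability as tfp

def bernoulli_log_probs_tfp(args):
    """Element-wise Bernoulli log-likelihood of data x under probabilities mu, with the same clipping."""
    mu, x = args
    mu = 1e-6 + (1 - 2e-6) * mu  # keep probabilities strictly inside (0, 1) so log_prob stays finite
    return tfp.distributions.Bernoulli(probs=mu).log_prob(x)

Assuming the base class wires `self.latent_input` and `self.data_input` to Keras `Input` tensors (that code is not shown in this file), `decoder.generator.predict(z)` returns the Bernoulli means for a batch of latent codes, and `decoder.ll_estimator.predict([x, z])` returns the per-dimension log-likelihoods.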