def __init__(self, latent_dim, data_dim, network_architecture='synthetic'):
    """Wire up the standard Bernoulli decoder and its two Keras models.

    Builds a sampling model (`self.generator`: latent code -> Bernoulli
    probabilities) and a log-likelihood model (`self.ll_estimator`:
    (data, latent code) -> Bernoulli log-probabilities of the data).

    Args:
        latent_dim: int, the flattened dimensionality of the latent space
        data_dim: int, the flattened dimensionality of the output space (data space)
        network_architecture: str, the architecture name for the body of the Decoder model
    """
    super(Decoder, self).__init__(latent_dim=latent_dim, data_dim=data_dim,
                                  network_architecture=network_architecture,
                                  name='Standard Decoder')
    # NOTE: all decoder layers have names prefixed by `dec`.
    # This is essential for the partial model freezing during training.
    body = get_network_by_name['decoder'][network_architecture](self.latent_input)
    probs = Dense(self.data_dim, activation='sigmoid', name='dec_sampler_params')(body)
    # Clip the probabilities away from the borders {0, 1}, where the
    # Bernoulli `log_prob` would otherwise produce NaNs.
    probs = Lambda(lambda p: 1e-6 + (1 - 2e-6) * p, name='dec_probs_clipper')(probs)

    def bernoulli_log_probs(args):
        # Local import keeps the TF-contrib dependency confined to this layer.
        from tensorflow.contrib.distributions import Bernoulli
        probabilities, observed = args
        return Bernoulli(probs=probabilities, name='dec_bernoulli').log_prob(observed)

    log_probs = Lambda(bernoulli_log_probs, name='dec_bernoulli_logprob')([probs, self.data_input])
    # Model producing the Bernoulli parameters from a latent sample.
    self.generator = Model(inputs=self.latent_input, outputs=probs, name='dec_sampling')
    # Trainable model estimating the data log-likelihood under the decoder.
    self.ll_estimator = Model(inputs=[self.data_input, self.latent_input], outputs=log_probs, name='dec_trainable')