def recode_cost(self, inputs, variation, eps=1e-5, **kwargs):
    """
    Cost for given input batch of samples, under current params.

    The batch is encoded to a latent mean and log-variance, a latent sample
    is drawn using ``variation``, the sample is decoded to an output mean and
    variance, and the two loss terms below are summed over axis 1 and
    averaged over the batch.

    Args:
        inputs: Batch of samples; passed to ``self.get_h_inputs`` and compared
            against the decoded mean.
        variation: Noise multiplied by the latent standard deviation to draw
            the latent sample (presumably standard-normal draws with the
            same shape as the latent mean -- confirm against the caller).
        eps: Small constant added to the decoded variance before it is
            logged and divided by, so that a variance of zero cannot produce
            inf/NaN.
        **kwargs: Accepted and ignored.

    Returns:
        Scalar tensor: ``0.5 * mean(like_loss + lat_loss)``.
    """
    # Encoder: hidden representation -> latent mean and log-variance.
    h_enc = self.get_h_inputs(inputs)
    z_mu = tf.matmul(h_enc, self.params['Mhz']) + self.params['bMhz']
    # z_sig is used as a log-variance below (exp(z_sig) is the variance).
    z_sig = tf.matmul(h_enc, self.params['Shz']) + self.params['bShz']

    # Twice the KL divergence between N(z_mu, exp(z_sig)) and a standard
    # normal; the missing factor of .5 is applied in the return statement.
    lat_loss = -tf.reduce_sum(1 + z_sig - z_mu**2 - tf.exp(z_sig), 1)

    # Draw the latent sample as mean + std * noise so that it stays a
    # differentiable function of z_mu and z_sig.
    z = z_mu + tf.sqrt(tf.exp(z_sig)) * variation

    # Decoder: latent sample -> output mean and variance.
    h_dec = self.get_h_latents(z)
    x_mu = self.decoding(tf.matmul(h_dec, self.params['Mhx']) + self.params['bMhx'])
    x_sig = self.decoding(tf.matmul(h_dec, self.params['Shx']) + self.params['bShx'])
    # Disabled alternative: tie the variance to the mean.
    # x_sig = tf.clip_by_value(x_mu * (1 - x_mu), .05, 1)

    # Decoding likelihood term: twice the Gaussian negative log-likelihood
    # (up to an additive constant), with x_sig treated as the variance.
    # The same eps-stabilised variance is used in the log and in the
    # denominator; guarding only the log would still let x_sig == 0 turn the
    # quadratic term into inf/NaN.
    x_var = x_sig + eps
    like_loss = tf.reduce_sum(tf.log(x_var) + (inputs - x_mu)**2 / x_var, 1)
    # Disabled alternative: mean cross entropy between input and
    # encode-decoded input.
    # like_loss = 2 * tf.reduce_sum(functions.cross_entropy(inputs, x_mu), 1)

    return .5 * tf.reduce_mean(like_loss + lat_loss)
评论列表
文章目录