def _decoder(self, z):
    """Build the p(x|z) decoder network, plus a class head when ``z`` is given.

    When ``z`` is ``None`` the decoder is fed ``self.epsilon`` directly and no
    latent statistics or class outputs are produced. Otherwise ``z`` is
    reshaped to ``[-1, 2 * hidden_size]``, split in two along axis 1 into a
    mean half and a second half that is mapped through ``sqrt(exp(.))`` to get
    the standard deviation; the decoder input is then
    ``mean + self.epsilon * stddev``.

    Args:
        z: Encoder output tensor, or ``None`` to decode from ``self.epsilon``
            alone. Presumably holds ``2 * hidden_size`` values per example --
            TODO confirm against the encoder.

    Returns:
        A 5-tuple ``(decoded, mean, stddev, class_predictions, logits)``.
        ``decoded`` is the output of the deconvolution stack. The remaining
        four entries are ``None`` when ``z`` is ``None``. Note that
        ``class_predictions`` is the raw output of the fully connected layer
        and ``logits`` is its softmax, despite what the names suggest.
    """
    # Defaults for the "no encoder output" path.
    mean = stddev = class_predictions = logits = None

    if z is None:
        latent = self.epsilon
    else:
        width = self.flags['hidden_size'] * 2
        flat = tf.reshape(z, [-1, width])
        # NOTE(review): (axis, num_splits, value) looks like the pre-1.0
        # TensorFlow argument order for tf.split -- confirm the TF version.
        mean, log_var = tf.split(1, 2, flat)
        stddev = tf.sqrt(tf.exp(log_var))

        # Class head, built on the mean only (not on the sampled latent).
        classifier = Layers(mean)
        classifier.fc(self.flags['num_classes'])
        class_predictions = classifier.get_output()
        logits = tf.nn.softmax(class_predictions)

        # Scale the noise in self.epsilon by stddev and shift it by the mean.
        latent = mean + self.epsilon * stddev

    # Insert two singleton axes at position 1 before the deconvolution stack.
    seed = tf.expand_dims(latent, 1)
    seed = tf.expand_dims(seed, 1)

    net = Layers(seed)
    net.deconv2d(3, 128, padding='VALID')
    net.deconv2d(3, 64, padding='VALID', stride=2)
    net.deconv2d(3, 64, stride=2)
    net.deconv2d(5, 32, stride=2)
    # Final layer: a single output with tanh activation.
    net.deconv2d(7, 1, activation_fn=tf.nn.tanh, s_value=None)

    return net.get_output(), mean, stddev, class_predictions, logits
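# Comment list          (web-page navigation residue; not part of the code)
# Article contents      (web-page navigation residue; not part of the code)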