def _decoder(self, z):
    """Define the p(x|z) decoder network.

    Builds a latent sample and passes it through a stack of transposed
    convolutions (via the project ``Layers`` helper) to produce the output.

    Args:
        z: Encoder output holding latent statistics, or ``None``.
            When given, it is reshaped to ``[-1, 2 * hidden_size]`` and split
            into two halves along axis 1. The first half is used as the mean
            and the second half as the log-variance. When ``None``,
            ``self.epsilon`` is fed to the decoder directly. That is presumably
            sampling from the prior; confirm that ``self.epsilon`` is
            standard-normal noise.

    Returns:
        A tuple ``(output, mean, stddev)``. ``output`` is
        ``decoder.get_output()``, and the final layer uses ``tf.nn.tanh``.
        ``mean`` and ``stddev`` are the latent statistics; both are ``None``
        when ``z`` is ``None``.
    """
    if z is None:
        mean = None
        stddev = None
        input_sample = self.epsilon
    else:
        hidden_size = self.flags['hidden_size']
        z = tf.reshape(z, [-1, hidden_size * 2])
        # NOTE(review): this uses the pre-1.0 TensorFlow argument order
        # tf.split(split_dim, num_split, value). On TF >= 1.0 the equivalent
        # is tf.split(z, 2, axis=1). Confirm against the TF version in use.
        mean, log_var = tf.split(1, 2, z)
        # exp(0.5 * log_var) == sqrt(exp(log_var)), but it does not overflow
        # to inf (and produce NaN gradients) for moderately large log_var.
        stddev = tf.exp(0.5 * log_var)
        # Reparameterization: the noise is scaled and shifted, so the sample
        # stays differentiable with respect to mean and stddev.
        input_sample = mean + self.epsilon * stddev
    # Insert two singleton axes after the batch axis so the latent vector
    # becomes a 1x1 spatial map for the deconvolution stack. This assumes
    # input_sample is [batch, hidden_size] (NHWC-style), which is true on the
    # z path; TODO confirm for self.epsilon on the z is None path.
    decoder = Layers(tf.expand_dims(tf.expand_dims(input_sample, 1), 1))
    # NOTE(review): Layers.deconv2d is defined elsewhere. Its positional
    # arguments are presumably (kernel_size, output_channels); verify.
    decoder.deconv2d(3, 128, padding='VALID')
    decoder.deconv2d(3, 128, padding='VALID', stride=2)
    decoder.deconv2d(3, 64, stride=2)
    decoder.deconv2d(3, 64, stride=2)
    decoder.deconv2d(5, 1, activation_fn=tf.nn.tanh, s_value=None)
    return decoder.get_output(), mean, stddev
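# 评论列表 ("Comment list") — scraped web-page residue, not part of the code
# 文章目录 ("Table of contents") — scraped web-page residue, not part of the code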