def decoder(self, z, embedding, reuse=None):
    """Decode latent codes ``z`` into a 3-channel output tensor.

    Builds, under the ``"decoder"`` variable scope:
      1. a linear fully-connected projection of ``z`` to
         ``8 * 8 * embedding`` units, reshaped to a
         ``[-1, 8, 8, embedding]`` feature map;
      2. ``self.conv_repeat_num`` stages, each made of two 3x3 ELU
         convolutions (stride 1, SAME padding, ``embedding`` filters),
         with ``resize_nn(x, 2)`` up-sampling between consecutive stages
         (none after the last stage);
      3. a final linear 3x3 convolution to 3 output channels.

    All conv / fully-connected weights use variance-scaling initialization
    and an L2 regularizer with scale 5e-4. All biases are zero-initialized.

    Args:
        z: Latent input tensor fed to ``slim.fully_connected``. It is
            presumably shaped ``[batch, z_dim]``; TODO confirm against
            callers.
        embedding: Number of channels of the initial feature map and number
            of filters in every intermediate convolution.
        reuse: Forwarded to ``tf.variable_scope``. Pass ``True`` to share the
            variables created by an earlier call.

    Returns:
        The output of the final convolution: 3 channels, no activation.
        If ``resize_nn(x, 2)`` doubles height and width (assumed; that
        helper is defined elsewhere), the spatial size is
        ``8 * 2 ** (self.conv_repeat_num - 1)``.
    """
    with tf.variable_scope("decoder", reuse=reuse):
        # NOTE: TF-Slim layers spell this keyword `biases_initializer`.
        # The previous `bias_initializer` was forwarded verbatim by
        # arg_scope. slim.conv2d / slim.fully_connected then rejected it
        # with a TypeError.
        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            weights_initializer=tf.contrib.layers.variance_scaling_initializer(),
                            weights_regularizer=slim.l2_regularizer(5e-4),
                            biases_initializer=tf.zeros_initializer()):
            # Defaults for every convolution below: size-preserving
            # (SAME padding, stride 1) with ELU activation.
            with slim.arg_scope([slim.conv2d], padding="SAME",
                                activation_fn=tf.nn.elu, stride=1):
                # Linear projection of the latent code, folded into an
                # 8x8 feature map with `embedding` channels.
                x = slim.fully_connected(z, 8 * 8 * embedding, activation_fn=None)
                x = tf.reshape(x, [-1, 8, 8, embedding])
                for i in range(self.conv_repeat_num):
                    # Two stacked 3x3 convolutions with `embedding` filters.
                    x = slim.repeat(x, 2, slim.conv2d, embedding, 3)
                    # Up-sample between stages only, not after the last one.
                    if i < self.conv_repeat_num - 1:
                        x = resize_nn(x, 2)  # NN up-sampling
                # Project to 3 output channels. activation_fn=None overrides
                # the ELU default, so the output is linear.
                x = slim.conv2d(x, 3, 3, activation_fn=None)
    return x
评论列表
文章目录