def decoder(layers, y, BATCH_SIZE):
    """Decode the encoder's latent code, concatenated with `y`, into an image tensor.

    The last encoder output is flattened, joined with `y` along axis 1,
    reshaped to a 1x1 spatial map, and upsampled by six stride-2, 4x4
    transpose convolutions with 512, 512, 256, 128, 64 and 3 output channels.

    Args:
        layers: indexable whose first element is the sequence of encoder
            outputs; only the last of those outputs (the latent code) is used.
            NOTE(review): the extra level of nesting looks caller-specific;
            confirm against the caller.
        y: rank-2 tensor (tf.concat with the flattened latent requires the
            same rank), presumably a per-example conditioning vector -- confirm.
        BATCH_SIZE: static batch size used in the reshape; it must equal the
            actual batch dimension of the latent code.

    Returns:
        The output tensor of the final transpose convolution (3 channels).
        Assuming tcl's default 'SAME' padding, each layer doubles height and
        width, so the result is [BATCH_SIZE, 64, 64, 3].
    """
    layers = layers[0]
    # Only the last encoder output is used. The earlier outputs were once
    # unpacked for skip connections, which are currently not wired in.

    # z is the latent encoding (the last encoder layer) joined with y.
    z = tcl.flatten(layers[-1])
    z = tf.concat([z, y], axis=1)
    print('z: %s' % (z,))

    # Treat the feature vector as a 1x1 image with s channels so the
    # transpose convolutions can upsample it.
    s = z.get_shape().as_list()[-1]
    z = tf.reshape(z, [BATCH_SIZE, 1, 1, s])
    print('z: %s' % (z,))

    net = z
    # Scope names g_dec_conv1..g_dec_conv6 are kept so that existing
    # variables and checkpoints still match.
    for index, num_outputs in enumerate((512, 512, 256, 128, 64, 3), start=1):
        name = 'dec_conv%d' % index
        # NOTE(review): the final (3-channel) layer also uses ReLU, so every
        # output value is >= 0. If target images are scaled to [-1, 1], a tanh
        # or linear output was probably intended -- confirm before changing.
        net = tcl.convolution2d_transpose(
            net, num_outputs, 4, 2,
            activation_fn=tf.nn.relu,
            weights_initializer=tf.random_normal_initializer(stddev=0.02),
            scope='g_' + name)
        print('%s: %s' % (name, net))

    print('')
    print('END G')
    print('')
    return net
评论列表
文章目录