def discriminator(input_images, reuse=False, is_training=None):
    """Autoencoder-style discriminator: encode images to a flat code, then decode.

    Encoder: three 4x4 stride-2 convolutions with 32, 64 and 128 filters, each
    followed by batch norm and a leaky ReLU (negative slope 0.2). The final
    feature map is flattened into ``encode``.

    Decoder: ``encode`` is reshaped back to a feature map and passed through
    three 4x4 stride-2 transposed convolutions with 64, 32 and 3 filters. The
    first two are followed by batch norm + leaky ReLU; the last feeds a ``tanh``.

    All layers are created under explicit ``Discriminator/...`` scopes, so the
    network can be rebuilt with ``reuse=True`` to share its variables.

    Args:
        input_images: image batch tensor. Assumed NHWC with
            height == width == IMAGE_SIZE and IMAGE_SIZE divisible by 8; the
            two reshapes below only line up under that assumption
            (TODO confirm against callers).
        reuse: forwarded to every conv / transposed-conv / batch-norm layer;
            pass True to reuse variables created by an earlier call.
        is_training: batch-norm mode. When None (the default), the module-level
            ``train_phase`` is used, which matches the original behaviour.

    Returns:
        Tuple ``(encode, decoded)``:
        - ``encode``: flattened code of shape [batch, 2 * IMAGE_SIZE * IMAGE_SIZE].
        - ``decoded``: 3-channel ``tanh`` reconstruction.
    """
    if is_training is None:
        # Preserve the original implicit dependency on the global flag.
        is_training = train_phase

    def _leaky_relu(x, index):
        # max(0.2 * x, x) is a leaky ReLU with negative slope 0.2.
        return tf.maximum(0.2 * x, x, name="Discriminator/leaky_relu_%d" % index)

    def _conv_bn_lrelu(x, num_outputs, index):
        # One encoder stage: 4x4 stride-2 conv -> batch norm -> leaky ReLU.
        x = slim.conv2d(inputs=x,
                        num_outputs=num_outputs,
                        kernel_size=[4, 4],
                        stride=2,
                        padding='SAME',
                        scope="Discriminator/conv_%d" % index)
        x = slim.batch_norm(inputs=x, scope="Discriminator/bn_%d" % index)
        return _leaky_relu(x, index)

    def _deconv(x, num_outputs, index):
        # 4x4 stride-2 transposed conv with SAME padding.
        return slim.conv2d_transpose(inputs=x,
                                     num_outputs=num_outputs,
                                     kernel_size=[4, 4],
                                     stride=2,
                                     padding='SAME',
                                     scope="Discriminator/deconv_%d" % index)

    with slim.arg_scope([slim.batch_norm],
                        is_training=is_training, reuse=reuse,
                        decay=0.9, epsilon=1e-5,
                        param_initializers={
                            "beta": tf.constant_initializer(value=0),
                            "gamma": tf.random_normal_initializer(mean=1, stddev=0.045),
                        }):
        with slim.arg_scope([slim.conv2d, slim.conv2d_transpose],
                            weights_initializer=tf.truncated_normal_initializer(stddev=0.045),
                            biases_initializer=tf.constant_initializer(value=0),
                            activation_fn=None, reuse=reuse):
            # Encoder: three downsampling stages (scopes conv_1..3 / bn_1..3).
            features = input_images
            for index, filters in enumerate((32, 64, 128), start=1):
                features = _conv_bn_lrelu(features, filters, index)

            # (IMAGE_SIZE // 8)^2 positions * 128 channels == 2 * IMAGE_SIZE^2.
            encode = tf.reshape(features, [-1, 2 * IMAGE_SIZE * IMAGE_SIZE],
                                name="Discriminator/encode")

            # Decoder: undo the flattening, then upsample back.
            features = tf.reshape(encode,
                                  [-1, IMAGE_SIZE // 8, IMAGE_SIZE // 8, 128],
                                  name="Discriminator/encode_reshape")
            for index, filters in ((4, 64), (5, 32)):
                features = _deconv(features, filters, index)
                features = slim.batch_norm(features, scope="Discriminator/bn_%d" % index)
                features = _leaky_relu(features, index)

            # Output layer: no batch norm, tanh activation.
            features = _deconv(features, 3, 6)
            decoded = tf.nn.tanh(features, name="Discriminator/tanh_6")

    return encode, decoded
# mean squared errors
# Blog-page navigation residue (not code): "评论列表" (comment list), "文章目录" (table of contents).