def dis_net(self, images, y, reuse=False):
    """Build the label-conditioned discriminator graph.

    Every layer is created inside the ``"discriminator"`` variable scope.
    The labels ``y`` are injected four times: joined with the input images,
    with the first convolution's activation, with the flattened second
    convolution's activation, and with the first fully connected activation.

    As a side effect, the two convolution weights and the two convolution
    activations are registered in the graph collections ``'weight_1'``,
    ``'weight_2'``, ``'ac_1'`` and ``'ac_2'``.

    Args:
        images: Image batch fed to the first convolution. The original
            author's note says MNIST data of shape (28, 28, 1); presumably
            batched as (batch_size, 28, 28, 1) -- TODO confirm.
        y: Conditioning labels. Reshaped here to
            ``[self.batch_size, 1, 1, self.y_dim]`` and also concatenated
            along axis 1 with 2-D tensors, so it is presumably shaped
            (batch_size, y_dim) -- TODO confirm against callers.
        reuse: When equal to ``True``, the scope is switched to reuse its
            existing variables instead of creating new ones.

    Returns:
        A 2-tuple ``(tf.nn.sigmoid(out), out)`` where ``out`` is the output
        of the final fully connected layer.
    """
    with tf.variable_scope("discriminator") as disc_scope:
        # Deliberately an explicit comparison: only values equal to True
        # flip the scope into reuse mode.
        if reuse == True:
            disc_scope.reuse_variables()

        # Labels reshaped so they can be joined with 4-D feature maps.
        label_map = tf.reshape(
            y, shape=[self.batch_size, 1, 1, self.y_dim])

        # First convolution stage, conditioned on the labels before and after.
        conditioned_input = conv_cond_concat(images, label_map)
        h1, kernel1 = conv2d(
            conditioned_input, output_dim=10, name='dis_conv1')
        tf.add_to_collection('weight_1', kernel1)
        h1 = conv_cond_concat(lrelu(h1), label_map)
        tf.add_to_collection('ac_1', h1)

        # Second convolution stage followed by batch normalisation.
        h2, kernel2 = conv2d(h1, output_dim=64, name='dis_conv2')
        tf.add_to_collection('weight_2', kernel2)
        # NOTE(review): unlike 'dis_bn2' below, this call does not forward
        # `reuse`; presumably the enclosing scope's reuse flag is inherited
        # -- confirm against the batch_normal helper.
        h2_normed = batch_normal(h2, scope='dis_bn1')
        h2 = lrelu(h2_normed)
        tf.add_to_collection('ac_2', h2)

        # Flatten per sample and append the labels.
        flat = tf.reshape(h2, [self.batch_size, -1])
        flat = tf.concat([flat, y], 1)

        # First fully connected stage: dense -> batch norm -> leaky ReLU.
        fc1 = fully_connect(flat, output_size=1024, scope='dis_fully1')
        fc1 = batch_normal(fc1, scope='dis_bn2', reuse=reuse)
        fc1 = tf.concat([lrelu(fc1), y], 1)

        # Single-unit output; both the squashed and the raw value are returned.
        logits = fully_connect(fc1, output_size=1, scope='dis_fully2')
        return tf.nn.sigmoid(logits), logits
评论列表
文章目录