def discriminate(self, image, Y):
    """Build the label-conditioned discriminator graph for a batch of images.

    The label matrix ``Y`` is fed in three times:

    * broadcast to constant feature maps appended to the image channels,
    * broadcast again and appended to the channels of the first conv output,
    * appended to the flattened second conv output before the dense layer.

    The broadcast label maps take their height/width from the tensor they
    are concatenated with, instead of using hard-coded sizes.

    Args:
        image: 4-D NHWC tensor (channels are concatenated on axis 3) with
            ``self.batch_size`` examples. Its flattened conv features plus
            ``dim_y`` must still match the row count of
            ``self.discrim_W3`` for the final ``tf.matmul`` to succeed.
        Y: 2-D label tensor of shape ``[self.batch_size, self.dim_y]``.
            The reshape below and the axis-1 concat with the flattened
            features require this shape.

    Returns:
        ``lrelu(batchnormalize(h2 @ self.discrim_W3))``, where ``h2`` is the
        flattened second-layer features concatenated with ``Y``. ``lrelu``
        and ``batchnormalize`` are helpers defined elsewhere in this file.
    """
    print("Initializing the discriminator")
    print("Y shape", Y.get_shape())
    # [batch, dim_y] -> [batch, 1, 1, dim_y] so the labels broadcast over H and W.
    yb = tf.reshape(Y, tf.stack([self.batch_size, 1, 1, self.dim_y]))
    print("image shape", image.get_shape())
    print("yb shape", yb.get_shape())

    def label_maps(feature_map):
        # Broadcast yb to [batch, H, W, dim_y] using feature_map's own spatial
        # size. Slicing to one channel gives a ones template whose product
        # with yb has exactly dim_y channels.
        return yb * tf.ones_like(feature_map[:, :, :, :1])

    X = tf.concat([image, label_maps(image)], 3)
    print("X shape", X.get_shape())
    h1 = lrelu(
        tf.nn.conv2d(X, self.discrim_W1, strides=[1, 2, 2, 1], padding='SAME')
    )
    print("h1 shape", h1.get_shape())
    # Re-inject the labels at the (downsampled) resolution of h1.
    h1 = tf.concat([h1, label_maps(h1)], 3)
    print("h1 shape", h1.get_shape())
    h2 = lrelu(batchnormalize(
        tf.nn.conv2d(h1, self.discrim_W2, strides=[1, 2, 2, 1], padding='SAME')
    ))
    print("h2 shape", h2.get_shape())
    # Flatten per example, then append the raw labels for the dense layer.
    h2 = tf.reshape(h2, [self.batch_size, -1])
    h2 = tf.concat([h2, Y], 1)
    discri = tf.matmul(h2, self.discrim_W3)
    print("discri shape", discri.get_shape())
    h3 = lrelu(batchnormalize(discri))
    return h3
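# Web-page residue from the scraped source: 评论列表 ("Comment list")
# Web-page residue from the scraped source: 文章目录 ("Table of contents")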