model_mnist.py source code

python

Project: Conditional-Gans  Author: zhangqianhui
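import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope)

# conv2d, lrelu, batch_normal, fully_connect and conv_cond_concat are helper
# functions defined elsewhere in the project; they are not shown here.
def dis_net(self, images, y, reuse=False):
        """Conditional discriminator: returns (sigmoid probability, raw logit)."""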
        with tf.variable_scope("discriminator") as scope:

            if reuse:
                scope.reuse_variables()

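            # MNIST images have shape (28, 28, 1). Reshape the label vector to
            # (batch, 1, 1, y_dim) so it can be attached to feature maps.
            yb = tf.reshape(y, shape=[self.batch_size, 1, 1, self.y_dim])
            # Condition the input by concatenating the label map to the images.
            concat_data = conv_cond_concat(images, yb)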

            conv1, w1 = conv2d(concat_data, output_dim=10, name='dis_conv1')
            tf.add_to_collection('weight_1', w1)

            conv1 = lrelu(conv1)
            conv1 = conv_cond_concat(conv1, yb)
            tf.add_to_collection('ac_1', conv1)


            conv2, w2 = conv2d(conv1, output_dim=64, name='dis_conv2')
            tf.add_to_collection('weight_2', w2)

            conv2 = lrelu(batch_normal(conv2, scope='dis_bn1', reuse=reuse))
            tf.add_to_collection('ac_2', conv2)

            conv2 = tf.reshape(conv2, [self.batch_size, -1])
            conv2 = tf.concat([conv2, y], 1)

            f1 = lrelu(batch_normal(fully_connect(conv2, output_size=1024, scope='dis_fully1'), scope='dis_bn2', reuse=reuse))
            f1 = tf.concat([f1, y], 1)

            out = fully_connect(f1, output_size=1, scope='dis_fully2')

            return tf.nn.sigmoid(out), out
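
The label vector is injected at four points in the discriminator above: before each convolution and before each fully connected layer. As a rough aid for following the tensor shapes, here is a minimal pure-Python sketch that traces them through the network. It rests on assumptions not confirmed by the snippet: that `conv_cond_concat` tiles the label map over height and width and concatenates along the channel axis, that `conv2d` uses stride 2 with 'SAME' padding, and that `y_dim` is 10 (one-hot MNIST labels). The function names and the batch size of 64 are hypothetical and chosen only for illustration.

```python
import math


def cond_concat_shape(x_shape, y_dim):
    """Shape after concatenating a tiled (B, 1, 1, y_dim) label map to x
    along the channel axis (assumed behavior of conv_cond_concat)."""
    b, h, w, c = x_shape
    return (b, h, w, c + y_dim)


def conv_same_shape(x_shape, output_dim, stride):
    """Shape after a 'SAME'-padded convolution with the given stride
    (assumed behavior of the conv2d helper)."""
    b, h, w, _ = x_shape
    return (b, math.ceil(h / stride), math.ceil(w / stride), output_dim)


def trace_dis_net_shapes(batch_size=64, y_dim=10, stride=2):
    """Return the shape of each intermediate tensor in dis_net."""
    shapes = {}

    x = (batch_size, 28, 28, 1)                # MNIST input images
    x = cond_concat_shape(x, y_dim)            # images + label map
    shapes["concat_data"] = x

    x = conv_same_shape(x, 10, stride)         # dis_conv1, output_dim=10
    x = cond_concat_shape(x, y_dim)            # conv1 + label map
    shapes["conv1_with_label"] = x

    x = conv_same_shape(x, 64, stride)         # dis_conv2, output_dim=64
    shapes["conv2"] = x

    flat = x[1] * x[2] * x[3]                  # reshape to [batch_size, -1]
    shapes["flat_with_label"] = (batch_size, flat + y_dim)

    shapes["f1_with_label"] = (batch_size, 1024 + y_dim)  # dis_fully1 + y
    shapes["out"] = (batch_size, 1)                       # dis_fully2 logit
    return shapes


if __name__ == "__main__":
    for name, shape in trace_dis_net_shapes().items():
        print(name, shape)
```

If the real `conv2d` helper uses a different stride or padding, only `conv_same_shape` needs to change; the label-concatenation arithmetic stays the same.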