models.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:CausalGAN 作者: mkocaoglu 项目源码 文件源码
def DiscriminatorCNN(image, config, reuse=None):
    '''
    Discriminator for the GAN model.

    image      : batch_size x 64x64x3 image tensor
    config     : see causal_dcgan/config.py; this function reads
                 config.df_dim, config.stab_proj and, when stab_proj is
                 set, config.n_stab_proj
    reuse      : pass True if not calling for first time (forwarded to
                 tf.variable_scope("discriminator"))

    returns: prob      : tf.nn.sigmoid(logits), probability(real)
           : logits    : pre-sigmoid output of the final linear layer
           : h1_       : activation of the second conv layer ('d_h1_conv'),
                         taken BEFORE minibatch features are appended; per
                         the original author's notes this is the activation
                         used to estimate z from
           : variables : list of TRAINABLE variables under the
                         "discriminator" scope (the fixed stab_proj weights
                         are created with trainable=False and are therefore
                         not included)

    raises: ValueError if the non-batch shape of the last conv layer is not
            fully known at graph-construction time (it is needed to flatten).
    '''
    with tf.variable_scope("discriminator", reuse=reuse) as vs:
        d_bn1 = batch_norm(name='d_bn1')
        d_bn2 = batch_norm(name='d_bn2')
        d_bn3 = batch_norm(name='d_bn3')

        if not config.stab_proj:
            # Standard first layer: trainable conv + leaky ReLU, df_dim channels.
            h0 = lrelu(conv2d(image, config.df_dim, name='d_h0_conv'))
        else:
            # Method to restrict the discriminator from winning: the first
            # layer is a fixed random projection (non-trainable weights and
            # no nonlinearity). Author's note: this should be equivalent to
            # not letting the discriminator optimize its first layer.
            # The paper used an 8x8 kernel and 32 projections; a 5x5 kernel
            # with stride 2 is used here to resemble the rest of this network.
            n_projs = config.n_stab_proj

            print("WARNING:STAB_PROJ active, using ",n_projs," projections")

            w_proj = tf.get_variable(
                'w_proj', [5, 5, image.get_shape()[-1], n_projs],
                initializer=tf.truncated_normal_initializer(stddev=0.02),
                trainable=False)
            conv = tf.nn.conv2d(image, w_proj, strides=[1, 2, 2, 1],
                                padding='SAME')

            # Zero-initialized and non-trainable, so this bias never changes
            # the output; left in place so existing variable names are kept.
            b_proj = tf.get_variable(
                'b_proj', [n_projs],
                initializer=tf.constant_initializer(0.0),
                trainable=False)
            h0 = tf.nn.bias_add(conv, b_proj)

        # Channel counts: df_dim*2 -> df_dim*4 -> df_dim*8 (e.g. 128, 256,
        # 512 for df_dim=64).
        h1_ = lrelu(d_bn1(conv2d(h0, config.df_dim*2, name='d_h1_conv')))
        h1 = add_minibatch_features(h1_, config.df_dim)
        h2 = lrelu(d_bn2(conv2d(h1, config.df_dim*4, name='d_h2_conv')))
        h3 = lrelu(d_bn3(conv2d(h2, config.df_dim*8, name='d_h3_conv')))

        # Flatten using the static (graph-construction time) feature size.
        h3_dims = h3.get_shape().as_list()[1:]
        if None in h3_dims:
            raise ValueError(
                "DiscriminatorCNN needs a fully defined shape for the last "
                "conv layer to flatten it, got %s" % (h3.get_shape(),))
        dim3 = int(np.prod(h3_dims))
        h3_flat = tf.reshape(h3, [-1, dim3])
        # 'd_h3_lin' is presumably linear()'s scope name; it is kept as-is
        # (even though this is the 4th layer) so variable names don't change.
        h4 = linear(h3_flat, 1, 'd_h3_lin')

        prob = tf.nn.sigmoid(h4)

        variables = tf.contrib.framework.get_variables(
            vs, collection=tf.GraphKeys.TRAINABLE_VARIABLES)

    return prob, h4, h1_, variables
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号