models.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:CausalGAN 作者: mkocaoglu 项目源码 文件源码
def GeneratorCNN( z, config, reuse=None):
    '''
    Build the generator graph: a linear projection of `z` followed by four
    transposed convolutions, ending in a tanh so outputs lie in [-1,1].

    Args:
        z: input tensor; its leading dimension is read at graph run time and
            used as the batch size of every deconv output shape.
        config: object providing `gf_dim` and `c_dim`. Note that `gf_dim` is
            used both as the output height/width and as the base multiplier
            for the channel counts (the original author annotated it as
            64, i.e. 64x64 images).
        reuse: forwarded to `tf.variable_scope("generator", ...)`.

    Returns:
        (out, variables): the tanh-activated output tensor and the result of
        `tf.contrib.framework.get_variables` on the "generator" scope.
    '''
    # Dynamic batch dimension, so no fixed batch size is baked into the graph.
    n_batch = tf.shape(z)[0]

    with tf.variable_scope("generator", reuse=reuse) as scope:
        # Normalizers are instantiated up front, in layer order, before any
        # other layer is built.
        norms = [batch_norm(name=bn_name)
                 for bn_name in ('g_bn0', 'g_bn1', 'g_bn2', 'g_bn3')]

        gf = config.gf_dim

        # sizes[k] holds (height, width) after k applications of
        # conv_out_size_same(., 2); presumably each one halves the size for a
        # stride-2 'SAME' convolution -- helper is defined elsewhere.
        sizes = [(gf, gf)]
        for _ in range(4):
            prev_h, prev_w = sizes[-1]
            sizes.append((conv_out_size_same(prev_h, 2),
                          conv_out_size_same(prev_w, 2)))
        seed_h, seed_w = sizes[4]

        # Project `z` and reshape it into the smallest feature map.
        projected, _, _ = linear(
            z, gf*8*seed_h*seed_w, 'g_h0_lin', with_w=True)
        net = tf.reshape(projected, [-1, seed_h, seed_w, gf * 8])
        net = tf.nn.relu(norms[0](net))

        # Three identical stages: deconv -> batch norm -> ReLU, each moving
        # one step up the size list while shrinking the channel count.
        stages = (
            ('g_h1', sizes[3], gf*4, norms[1]),
            ('g_h2', sizes[2], gf*2, norms[2]),
            ('g_h3', sizes[1], gf*1, norms[3]),
        )
        for layer_name, (out_h, out_w), channels, norm in stages:
            net, _, _ = deconv2d(
                net, [n_batch, out_h, out_w, channels],
                name=layer_name, with_w=True)
            net = tf.nn.relu(norm(net))

        # Final deconv to full size with c_dim channels; no batch norm here.
        full_h, full_w = sizes[0]
        pre_activation, _, _ = deconv2d(
            net, [n_batch, full_h, full_w, config.c_dim],
            name='g_h4', with_w=True)
        out = tf.nn.tanh(pre_activation)

    return out, tf.contrib.framework.get_variables(scope)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号