model_def.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:csgm 作者: AshishBora 项目源码 文件源码
def generator(hparams, z, train, reuse):
    """Build the generator graph that maps `z` to a 64x64 output tensor.

    The input is projected by `ops.linear`, reshaped to a 4x4 feature map,
    then passed through four `ops.deconv2d` layers whose requested output
    shapes double the spatial size each time (8, 16, 32, 64). Every layer
    except the last is followed by batch norm and ReLU; the last layer is
    followed by tanh.

    Args:
        hparams: Object providing `gf_dim` (channel multiplier),
            `batch_size` (leading dimension of each deconv output shape)
            and `c_dim` (channel count of the final output).
        z: Input passed to the initial linear projection.
            # presumably a [batch, latent_dim] tensor -- confirm with caller
        train: Forwarded as `train=` to every batch-norm call.
        reuse: If truthy, the current variable scope is switched to reuse
            existing variables before any layer is built.

    Returns:
        The tanh-activated output of the last deconv layer, requested with
        shape [hparams.batch_size, 64, 64, hparams.c_dim].
    """
    if reuse:
        tf.get_variable_scope().reuse_variables()

    side = 64  # spatial size of the generated output
    base_side = side // 16  # spatial size of the first feature map

    # All batch-norm objects are constructed up front, in name order, before
    # any layer is created.
    norms = [
        ops.batch_norm(name=bn_name)
        for bn_name in ('g_bn0', 'g_bn1', 'g_bn2', 'g_bn3')
    ]

    # Project `z` and reshape it into the initial feature map.
    projected = ops.linear(
        z, hparams.gf_dim * 8 * base_side * base_side, 'g_h0_lin')
    net = tf.reshape(
        projected, [-1, base_side, base_side, hparams.gf_dim * 8])
    net = tf.nn.relu(norms[0](net, train=train))

    # (layer name, batch norm, output spatial size, gf_dim multiplier)
    stages = (
        ('g_h1', norms[1], side // 8, 4),
        ('g_h2', norms[2], side // 4, 2),
        ('g_h3', norms[3], side // 2, 1),
    )
    for layer_name, norm, size, width in stages:
        target_shape = [
            hparams.batch_size, size, size, hparams.gf_dim * width]
        net = ops.deconv2d(net, target_shape, name=layer_name)
        net = tf.nn.relu(norm(net, train=train))

    # Final layer: no batch norm, tanh applied to the result.
    final_shape = [hparams.batch_size, side, side, hparams.c_dim]
    pre_activation = ops.deconv2d(net, final_shape, name='g_h4')
    return tf.nn.tanh(pre_activation)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号