gen.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:Keras-GAN 作者: Shaofanl 项目源码 文件源码
def _base_size(h, w, scale):
    """Return the (height, width) of the first feature map before upsampling.

    The generator doubles the spatial size ``scale`` times, so ``h`` and ``w``
    must be exact multiples of ``2**scale`` for the output to match the
    requested image size. Integer division keeps the result usable as Keras
    layer dimensions under Python 3, where ``/`` would produce floats.

    Raises:
        ValueError: if ``scale < 1`` or ``h``/``w`` is not divisible by
            ``2**scale``.
    """
    if scale < 1:
        raise ValueError("scale must be >= 1, got %r" % (scale,))
    factor = 2 ** scale
    if h % factor or w % factor:
        raise ValueError(
            "image height/width (%d, %d) must be divisible by 2**scale = %d"
            % (h, w, factor))
    return h // factor, w // factor


def basic_gen(input_shape, img_shape, nf=128, scale=4, FC=None, use_upsample=False):
    """Build a DCGAN-style generator model.

    The latent input optionally passes through hidden Dense layers. It is
    then projected and reshaped into a small feature map, which is upsampled
    by 2x ``scale`` times (halving the filter count at each hidden stage) up
    to ``img_shape``. The final layer uses a tanh activation.

    Args:
        input_shape: Shape of the latent input, passed to ``Input``.
        img_shape: ``(channels, height, width)`` of the generated image.
            Height and width must be divisible by ``2**scale``.
        nf: Filter count of the last hidden conv stage. The first feature
            map has ``nf * 2**(scale - 1)`` channels.
        scale: Number of 2x upsampling stages; must be >= 1.
        FC: Optional sequence of hidden Dense widths, each followed by
            BatchNormalization + ReLU, applied before the projection.
            ``None`` (default) means no hidden Dense layers.
        use_upsample: If True, upsample with ``UpSampling2D`` + 3x3
            ``Conv2D``. Otherwise use a stride-2 ``Deconv2D``.

    Returns:
        A Keras ``Model`` mapping the latent input to the generated image.

    Raises:
        ValueError: if ``scale < 1`` or the image size is not divisible by
            ``2**scale``.
    """
    dim, h, w = img_shape
    # Validate and compute integer sizes before building any layers.
    base_h, base_w = _base_size(h, w, scale)
    top_nf = nf * 2 ** (scale - 1)

    z = Input(input_shape)
    x = z
    for fc_dim in (FC or ()):
        x = Dense(fc_dim)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    x = Dense(top_nf * base_h * base_w)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # NOTE(review): the reshape is channels-first (C, H, W); this presumably
    # requires Keras image_data_format == 'channels_first' so that Conv2D /
    # Deconv2D treat axis 1 as channels -- confirm against the project config.
    x = Reshape((top_nf, base_h, base_w))(x)

    # scale - 1 hidden stages, each doubling H and W.
    for s in range(scale - 2, -1, -1):
        # UpSampling2D + Conv2D avoids the checkerboard artifacts that strided
        # transposed convolutions can produce:
        # http://distill.pub/2016/deconv-checkerboard/
        if use_upsample:
            x = UpSampling2D()(x)
            x = Conv2D(nf * 2 ** s, (3, 3), padding='same')(x)
        else:
            x = Deconv2D(nf * 2 ** s, (3, 3), strides=(2, 2), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Final (scale-th) doubling, mapping to the output channel count.
    if use_upsample:
        x = UpSampling2D()(x)
        x = Conv2D(dim, (3, 3), padding='same')(x)
    else:
        x = Deconv2D(dim, (3, 3), strides=(2, 2), padding='same')(x)

    x = Activation('tanh')(x)

    return Model(z, x)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号