resnet.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目：CycleGAN-keras　作者：Shaofanl　｜　项目源码　｜　文件源码
def resnet_6blocks(input_shape, output_nc, ngf, **kwargs):
    """Build the CycleGAN ResNet generator with six residual blocks.

    Layer-by-layer port of the Torch/Lua ``resnet_6blocks`` generator; the
    original Lua definitions are kept as comments above each stage.

    Pipeline:
    * 7x7 stem convolution (reflection-style padding via ``padding``).
    * Two stride-2 downsampling convolutions.
    * Six residual blocks at ``ngf*4`` channels.
    * Two stride-2 upsampling stages (``scaleup``).
    * A final padded 7x7 convolution with tanh.

    Args:
        input_shape: Shape given to ``Input``. Presumably (H, W, C) without
            the batch axis; confirm against the caller.
        output_nc: Number of channels produced by the final convolution.
        ngf: Filter count of the first convolution. Later stages use
            ``ngf*2`` and ``ngf*4``.
        **kwargs: Only ``name`` is read. It is forwarded to ``Model``
            (default ``None``).

    Returns:
        The Keras ``Model`` mapping the input to the tanh-activated output.
        Its summary is printed before returning.
    """
    ks = 3
    f = 7
    # Integer division: under Python 3, (f - 1) / 2 is the float 3.0, but
    # padding widths must be integers.
    p = (f - 1) // 2
    n_res_blocks = 6

    inputs = Input(input_shape)
    # local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(3, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
    x = padding((p, p))(inputs)
    x = Conv2D(ngf, (f, f))(x)
    x = normalize()(x)
    x = Activation('relu')(x)

    # local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
    x = Conv2D(ngf * 2, (ks, ks), strides=(2, 2), padding='same')(x)
    x = normalize()(x)
    x = Activation('relu')(x)

    # local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true)
    x = Conv2D(ngf * 4, (ks, ks), strides=(2, 2), padding='same')(x)
    x = normalize()(x)
    x = Activation('relu')(x)

    # local d1 = e3 - build_res_block(ngf*4, padding_type) x 6
    # Each call stacks a fresh residual block on top of the previous one.
    for _ in range(n_res_blocks):
        x = res_block(x, ngf * 4)

    # local d2 = d1 - nn.SpatialFullConvolution(ngf*4, ngf*2, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf*2) - nn.ReLU(true)
    # `scaleup` stands in for Conv2DTranspose(ngf*2, (ks, ks), strides=(2, 2), padding='same').
    x = scaleup(x, ngf * 2, (ks, ks), strides=(2, 2), padding='same')
    x = normalize()(x)
    x = Activation('relu')(x)

    # local d3 = d2 - nn.SpatialFullConvolution(ngf*2, ngf, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf) - nn.ReLU(true)
    x = scaleup(x, ngf, (ks, ks), strides=(2, 2), padding='same')
    x = normalize()(x)
    x = Activation('relu')(x)

    # local d4 = d3 - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(ngf, output_nc, f, f, 1, 1) - nn.Tanh()
    x = padding((p, p))(x)
    x = Conv2D(output_nc, (f, f))(x)
    x = Activation('tanh')(x)

    model = Model(inputs, x, name=kwargs.get('name', None))
    print('Model resnet 6blocks:')
    model.summary()
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号