def resnet_6blocks(input_shape, output_nc, ngf, **kwargs):
    """Build a ResNet-style image-to-image model with six residual blocks.

    The layer sequence mirrors a Lua Torch definition (``e1``..``e3``,
    ``d1``..``d4``; original Torch lines are summarized in the comments below):

    * reflection/extra padding + 7x7 conv to ``ngf`` filters, norm, ReLU;
    * two stride-2 3x3 convs to ``ngf*2`` and ``ngf*4`` filters, each + norm + ReLU;
    * six ``res_block`` calls at ``ngf*4`` filters;
    * two stride-2 ``scaleup`` stages back to ``ngf*2`` and ``ngf`` filters,
      each + norm + ReLU;
    * padding + 7x7 conv to ``output_nc`` channels, then ``tanh``.

    Args:
        input_shape: shape passed straight to ``Input`` (presumably
            ``(height, width, channels)`` — confirm against the backend's
            data format).
        output_nc: number of filters (output channels) of the final conv.
        ngf: base number of filters; internal stages use multiples of it.
        **kwargs: only ``name`` is read; it becomes the ``Model`` name
            (``None`` when absent).

    Returns:
        The constructed ``Model``. As a side effect, a header line is
        printed and ``model.summary()`` is called.
    """
    ks = 3  # kernel size of the strided down/up-sampling stages
    f = 7   # kernel size of the first and last convolutions
    # Integer division: Keras padding layers require int sizes. The previous
    # `(f-1)/2` yielded the float 3.0 under Python 3 and was rejected.
    p = (f - 1) // 2

    inputs = Input(input_shape)

    # e1: SpatialReflectionPadding(p,p,p,p) - SpatialConvolution(3, ngf, f, f, 1, 1)
    #     - normalization(ngf) - ReLU
    # NOTE(review): `padding` is an external factory, presumably a reflection
    # padding layer to match the Torch original — confirm.
    x = padding((p, p))(inputs)
    x = Conv2D(ngf, (f, f))(x)
    x = normalize()(x)
    x = Activation('relu')(x)

    # e2, e3: SpatialConvolution(..., ks, ks, 2, 2, 1, 1) - normalization - ReLU,
    # doubling the filter count at each stride-2 stage.
    for mult in (2, 4):
        x = Conv2D(ngf * mult, (ks, ks), strides=(2, 2), padding='same')(x)
        x = normalize()(x)
        x = Activation('relu')(x)

    # d1: six build_res_block(ngf*4, padding_type) in sequence.
    for _ in range(6):
        x = res_block(x, ngf * 4)

    # d2, d3: SpatialFullConvolution(..., ks, ks, 2, 2, 1, 1, 1, 1) - normalization
    # - ReLU. `scaleup` replaces an earlier Conv2DTranspose draft; it is an
    # external helper (presumably a stride-2 upsampling stage — confirm).
    for mult in (2, 1):
        x = scaleup(x, ngf * mult, (ks, ks), strides=(2, 2), padding='same')
        x = normalize()(x)
        x = Activation('relu')(x)

    # d4: SpatialReflectionPadding(p,p,p,p) - SpatialConvolution(ngf, output_nc, f, f, 1, 1)
    #     - Tanh
    x = padding((p, p))(x)
    x = Conv2D(output_nc, (f, f))(x)
    x = Activation('tanh')(x)

    model = Model(inputs, x, name=kwargs.get('name'))
    print('Model resnet 6blocks:')
    model.summary()
    return model
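# Scraped web-page residue (not code): "Comment list"
# Scraped web-page residue (not code): "Table of contents"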