def _reflect_conv_norm(x, nb_filter, kernel_size, strides, norm_name, relu=True):
    """Apply reflect-pad -> conv -> instance norm (-> optional ReLU) to ``x``.

    The tensor is reflect-padded by ``kernel_size // 2`` on each side and then
    convolved with ``border_mode='valid'``. For an odd kernel and stride 1 this
    keeps the spatial size unchanged, using reflection instead of zero padding.

    Args:
        x: Keras tensor to transform.
        nb_filter: number of convolution output channels.
        kernel_size: square kernel size (nb_row == nb_col == kernel_size).
        strides: ``subsample`` tuple passed to ``Convolution2D``.
        norm_name: value passed as the first argument of
            ``InstanceNormalization`` (used by the caller as the layer name).
        relu: when True, append a ReLU ``Activation`` layer.

    Returns:
        The output Keras tensor.
    """
    pad = kernel_size // 2
    x = ReflectPadding2D(padding=(pad, pad))(x)
    x = Convolution2D(nb_filter, kernel_size, kernel_size,
                      dim_ordering=K.image_dim_ordering(),
                      init='he_normal', subsample=strides,
                      border_mode='valid', activation='linear')(x)
    x = InstanceNormalization(norm_name)(x)
    if relu:
        x = Activation('relu')(x)
    return x


def fast_st_ps(input_shape, weights_path=None, mode=0, nb_res_layer=5):
    """Build the encoder / residual / PhaseShift-decoder image network.

    Architecture:
      * Downsampling: 9x9 stride-1 conv (32 ch), then two 3x3 stride-2 convs
        (64 ch, 128 ch), each followed by instance norm and ReLU.
      * ``nb_res_layer`` residual blocks of two 3x3 convs (128 ch) with
        instance norm; the second conv has no activation before the sum.
      * Upsampling: ``PhaseShift(ratio=4, color=False)`` (presumably a
        sub-pixel / depth-to-space rearrangement -- confirm against its
        definition), matching the total x4 stride of the two stride-2 convs.
      * A final 9x9 conv to 3 channels and ``ScaledSigmoid(scaling=255.)``
        (presumably mapping outputs into [0, 255] -- confirm).

    Args:
        input_shape: shape passed to ``Input`` (Keras convention: without the
            batch dimension).
        weights_path: if truthy, ``model.load_weights(weights_path)`` is called.
        mode: not used by this function; kept for interface compatibility.
        nb_res_layer: number of residual blocks.

    Returns:
        A Keras ``Model`` whose input layer is named ``'input_node'`` and whose
        output layer is named ``'output_node'``.
    """
    inputs = Input(shape=input_shape, name='input_node', dtype=K.floatx())

    # Downsampling. Layer creation order matches the original hand-written
    # sequence, so the graph (and order-based weight loading) is unchanged.
    x = _reflect_conv_norm(inputs, 32, 9, (1, 1), 'inorm-1')
    x = _reflect_conv_norm(x, 64, 3, (2, 2), 'inorm-2')
    x = _reflect_conv_norm(x, 128, 3, (2, 2), 'inorm-3')

    # Residual blocks. The irregular 'inorm-5-%d' name is kept on purpose so
    # previously saved weights / by-name lookups keep matching.
    for i in range(nb_res_layer):
        y = _reflect_conv_norm(x, 128, 3, (1, 1), 'inorm-res-%d' % i)
        y = _reflect_conv_norm(y, 128, 3, (1, 1), 'inorm-5-%d' % i, relu=False)
        x = merge([x, y], mode='sum')

    # Upsampling back to input resolution, then project to 3 channels.
    out = PhaseShift(ratio=4, color=False)(x)
    out = ReflectPadding2D(padding=(4, 4))(out)
    out = Convolution2D(3, 9, 9, dim_ordering=K.image_dim_ordering(),
                        init='he_normal', subsample=(1, 1),
                        border_mode='valid', activation='linear')(out)
    out = ScaledSigmoid(scaling=255., name="output_node")(out)

    model = Model(input=[inputs], output=[out])
    if weights_path:
        model.load_weights(weights_path)
    return model
评论列表
文章目录