def st_convt(input_shape, weights_path=None, mode=0, nb_res_layer=5):
    """Build the feed-forward style-transfer network (Keras 1.x API).

    Layers, in creation order:
      * Conv(32, 9x9, stride 1) -> BN -> ReLU
      * Conv(64, 3x3, stride 2) -> BN -> ReLU      (downsampling)
      * Conv(128, 3x3, stride 2) -> BN -> ReLU     (downsampling)
      * ``nb_res_layer`` residual blocks, each:
            Conv(128, 3x3) -> BN -> ReLU -> Conv(128, 3x3) -> BN,
        summed with the block input via ``merge(mode='sum')``
        (no ReLU after the sum)
      * ConvolutionTranspose2D(64, 3x3, stride 2) -> BN -> ReLU   (upsampling)
      * ConvolutionTranspose2D(32, 3x3, stride 2) -> BN -> ReLU   (upsampling)
      * Conv(3, 9x9, stride 1) -> ScaledSigmoid(scaling=255.)

    All convolutions use He-normal init, 'same' border mode and a linear
    activation. The nonlinearity comes from the separate Activation layers.

    Args:
        input_shape: Shape given to ``Input`` (Keras excludes the batch
            dimension). The channel axis is taken from
            ``K.image_dim_ordering()``: 3 for 'tf', otherwise 1.
        weights_path: If truthy, weights are loaded from this path with
            ``model.load_weights`` after the model is built.
        mode: Forwarded unchanged as ``BatchNormalization(mode=...)``.
        nb_res_layer: Number of residual blocks.

    Returns:
        A Keras ``Model`` whose input layer is named ``'input_node'`` and
        whose output layer is named ``'output_node'``.
    """
    dim_ordering = K.image_dim_ordering()
    # 'tf' ordering is channels-last; anything else is treated as channels-first.
    channel_axis = 3 if dim_ordering == 'tf' else 1

    # NOTE: layers are created in exactly the same order as in the original
    # hand-unrolled version. Keep it that way: Keras' ``load_weights`` without
    # ``by_name`` assigns saved weights by layer order.

    def _conv(x, nb_filter, size, stride):
        """Apply a linear, 'same'-padded, He-initialised square Convolution2D."""
        return Convolution2D(nb_filter, size, size, dim_ordering=dim_ordering,
                             init='he_normal', subsample=(stride, stride),
                             border_mode='same', activation='linear')(x)

    def _deconv(x, nb_filter, size, stride):
        """Apply a linear, 'same'-padded, He-initialised ConvolutionTranspose2D."""
        # ConvolutionTranspose2D is defined outside this block (not a stock
        # Keras 1 layer); its arguments mirror Convolution2D here.
        return ConvolutionTranspose2D(nb_filter, size, size, dim_ordering=dim_ordering,
                                      init='he_normal', subsample=(stride, stride),
                                      border_mode='same', activation='linear')(x)

    def _bn(x):
        """Batch-normalise over the channel axis with the shared settings."""
        return BatchNormalization(mode=mode, axis=channel_axis, momentum=0.1,
                                  gamma_init='he_normal')(x)

    def _bn_relu(x):
        """Apply batch norm followed by ReLU (BN layer is created first)."""
        x = _bn(x)
        return Activation('relu')(x)

    input_tensor = Input(shape=input_shape, name='input_node', dtype=K.floatx())

    # Downsampling path: one full-resolution conv, then two stride-2 convs.
    x = _bn_relu(_conv(input_tensor, 32, 9, 1))
    x = _bn_relu(_conv(x, 64, 3, 2))
    x = _bn_relu(_conv(x, 128, 3, 2))

    # Residual blocks. There is deliberately no ReLU after the second BN or
    # after the sum; the block input is added back unchanged.
    for _ in range(nb_res_layer):
        residual = _bn_relu(_conv(x, 128, 3, 1))
        residual = _bn(_conv(residual, 128, 3, 1))
        x = merge([x, residual], mode='sum')

    # Upsampling path: two stride-2 transposed convolutions.
    x = _bn_relu(_deconv(x, 64, 3, 2))
    x = _bn_relu(_deconv(x, 32, 3, 2))

    # Final 3-filter projection; ScaledSigmoid(scaling=255.) is defined outside
    # this block and presumably maps outputs to the 0-255 pixel range.
    x = _conv(x, 3, 9, 1)
    out = ScaledSigmoid(scaling=255., name="output_node")(x)

    model = Model(input=[input_tensor], output=[out])
    if weights_path:
        model.load_weights(weights_path)
    return model
# Moving from 4 to 12 layers doesn't seem to improve much
评论列表
文章目录