def bottleneck(encoder, output, upsample=False, reverse_module=False):
    """Build one bottleneck block: 1x1 reduce -> 3x3 (transposed) conv -> 1x1 expand.

    The main branch squeezes the input to ``output // 4`` channels, applies
    either a regular 3x3 convolution or a stride-2 transposed 3x3 convolution,
    and expands back to ``output`` channels.  Unless the block is upsampling
    without ``reverse_module``, the main branch is batch-normalised, summed
    with a skip branch and passed through a PReLU.

    Args:
        encoder: Input tensor the block is applied to.
        output: Number of channels produced by the block.
        upsample: If true, the middle convolution is a ``Conv2DTranspose``
            with strides ``(2, 2)`` instead of a plain 3x3 ``Conv2D``.
        reverse_module: ``False`` (the "absent" sentinel) or a tensor handed
            to ``MaxUnpooling2D`` together with the skip branch when
            upsampling.  Presumably the pooling indices of the matching
            encoder stage -- confirm against the caller.

    Returns:
        The block's output tensor.  When ``upsample`` is true and
        ``reverse_module`` is ``False`` this is the raw main-branch output
        (no final batch norm, no skip connection, no activation).
    """
    internal = output // 4

    # Main branch: 1x1 channel reduction.
    x = Conv2D(internal, (1, 1), use_bias=False)(encoder)
    x = BatchNormalization(momentum=0.1)(x)
    x = PReLU(shared_axes=[1, 2])(x)

    # Main branch: spatial convolution, transposed (stride 2) when upsampling.
    if upsample:
        x = Conv2DTranspose(
            filters=internal,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding='same',
        )(x)
    else:
        x = Conv2D(internal, (3, 3), padding='same', use_bias=True)(x)
    x = BatchNormalization(momentum=0.1)(x)
    x = PReLU(shared_axes=[1, 2])(x)

    # Main branch: 1x1 channel expansion.
    x = Conv2D(output, (1, 1), padding='same', use_bias=False)(x)

    # Skip branch: project with a 1x1 conv whenever the channel count differs
    # from ``output`` or the block upsamples.
    other = encoder
    if encoder.get_shape()[-1] != output or upsample:
        other = Conv2D(output, (1, 1), padding='same', use_bias=False)(other)
        other = BatchNormalization(momentum=0.1)(other)
        if upsample and reverse_module is not False:
            other = MaxUnpooling2D()([other, reverse_module])

    # Upsampling without ``reverse_module``: the skip branch is not unpooled,
    # so it is not merged and the main branch is returned as-is.
    if upsample and reverse_module is False:
        return x

    # Residual merge followed by the block's final activation.
    x = BatchNormalization(momentum=0.1)(x)
    decoder = add([x, other])
    decoder = PReLU(shared_axes=[1, 2])(decoder)
    return decoder
评论列表
文章目录