def get_unet0(num_start_filters=32):
    """Build a U-Net-style encoder/decoder Keras model.

    The graph has four down-sampling stages, a bottleneck, and four
    up-sampling stages. Each up-sampling stage is concatenated with the
    output of the matching down-sampling stage (skip connection). Filter
    counts go 1x, 2x, 4x, 8x, 16x ``num_start_filters`` on the way down and
    mirror that on the way up.

    The input shape is ``(img_rows, img_cols, num_channels)`` and the output
    has ``num_mask_channels`` channels; these names are resolved from the
    enclosing scope, not passed in. The final stage crops 16 pixels from
    every border before the output convolution.

    Args:
        num_start_filters: Number of filters used by the first and last
            stages; deeper stages use multiples of it.

    Returns:
        A Keras ``Model`` (not compiled here) whose output is the final 1x1
        convolution. No activation is passed to that convolution.
    """
    inputs = Input((img_rows, img_cols, num_channels))

    # Filter counts for the four contracting stages, shallowest first.
    encoder_widths = [num_start_filters] + [
        factor * num_start_filters for factor in (2, 4, 8)
    ]

    # Contracting path: remember every stage output for the skip connections,
    # then halve the resolution before the next stage.
    # NOTE(review): ConvBN2 is defined elsewhere; presumably a double
    # conv + batch-norm block that keeps spatial size -- confirm.
    skip_stack = []
    tensor = inputs
    for width in encoder_widths:
        stage = ConvBN2(tensor, width)
        skip_stack.append(stage)
        tensor = MaxPooling2D(pool_size=(2, 2))(stage)

    # Bottleneck at the lowest resolution.
    tensor = ConvBN2(tensor, 16 * num_start_filters)

    # Expanding path: upsample, merge with the most recent unused skip
    # (deepest first), then convolve.
    for factor in (8, 4, 2):
        upsampled = UpSampling2D(size=(2, 2))(tensor)
        merged = concatenate([upsampled, skip_stack.pop()])
        tensor = ConvBN2(merged, factor * num_start_filters)

    # Last expanding stage is written out by hand because the crop sits
    # between its second convolution and the following batch norm.
    upsampled = UpSampling2D(size=(2, 2))(tensor)
    merged = concatenate([upsampled, skip_stack.pop()])

    conv_options = dict(padding="same", kernel_initializer="he_uniform")
    tensor = Conv2D(num_start_filters, (3, 3), **conv_options)(merged)
    tensor = BatchNormalization()(tensor)
    tensor = Activation('selu')(tensor)
    tensor = Conv2D(num_start_filters, (3, 3), **conv_options)(tensor)
    # Drop a 16-pixel margin on all four sides.
    tensor = Cropping2D(cropping=((16, 16), (16, 16)))(tensor)
    tensor = BatchNormalization()(tensor)
    tensor = Activation('selu')(tensor)

    # 1x1 convolution maps the features to the mask channels.
    outputs = Conv2D(num_mask_channels, (1, 1))(tensor)
    return Model(inputs=inputs, outputs=outputs)
# [scraped page navigation, not code] Comment list
# [scraped page navigation, not code] Table of contents