def densenet_cifar10_model(logits=False, input_range_type=1, pre_filter=lambda x:x):
    """Build a DenseNet model (depth 40, growth rate 12) for 32x32x3 CIFAR-10 images.

    The hyper-parameters are fixed here; the layer graph itself is produced by
    ``__create_dense_net`` (adapted from https://github.com/titu1994/DenseNet).

    Args:
        logits: If truthy, the network head is built with ``activation=None``
            (raw logits); otherwise with ``"softmax"``.
        input_range_type: Only ``1`` is supported.  NOTE(review): presumably
            this selects the expected pixel range of the input; confirm
            against the callers.
        pre_filter: Accepted but never used by this function.  NOTE(review):
            the input is NOT passed through it.

    Returns:
        A ``Model`` named ``'densenet'`` whose single input has the shape
        returned by ``_obtain_input_shape`` for a 32x32 image with 3 channels
        (channel axis placed according to ``K.image_data_format()``).
        ``nb_classes = 10`` is passed to the network builder.

    Raises:
        ValueError: If ``input_range_type`` is not ``1``.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently accept an unsupported configuration.
    if input_range_type != 1:
        raise ValueError(
            "densenet_cifar10_model only supports input_range_type=1, got %r"
            % (input_range_type,)
        )

    nb_classes = 10
    img_rows, img_cols, img_channels = 32, 32, 3
    # Query the same backend setting that is handed to `_obtain_input_shape`
    # below; the legacy Keras-1 `K.image_dim_ordering()` is deprecated and
    # mixing the two APIs invites a layout mismatch.
    if K.image_data_format() == "channels_first":
        img_dim = (img_channels, img_rows, img_cols)
    else:
        img_dim = (img_rows, img_cols, img_channels)

    # DenseNet hyper-parameters.
    depth = 40
    nb_dense_block = 3
    growth_rate = 12
    nb_filter = 16
    dropout_rate = 0.0  # 0.0 for data augmentation
    include_top = True
    activation = None if logits else "softmax"

    # Determine proper input shape.
    input_shape = _obtain_input_shape(img_dim,
                                      default_size=32,
                                      min_size=8,
                                      data_format=K.image_data_format(),
                                      include_top=include_top)
    img_input = Input(shape=input_shape)

    # The bare positional values after `nb_filter` are presumably (per the
    # upstream titu1994 signature -- confirm against the local definition):
    # nb_layers_per_block=-1, bottleneck=False, reduction=0.0, dropout_rate,
    # weight_decay=1E-4, activation.
    x = __create_dense_net(nb_classes, img_input, include_top, depth,
                           nb_dense_block, growth_rate, nb_filter, -1, False,
                           0.0, dropout_rate, 1E-4, activation)

    # No external input tensor is accepted, so the model's only input is the
    # freshly created `img_input`.
    return Model(img_input, x, name='densenet')
# Source: https://github.com/titu1994/DenseNet
评论列表
文章目录