densenet_models.py source code

python
Views: 20 · Favorites: 0 · Likes: 0 · Comments: 0

Project: EvadeML-Zoo · Author: mzweilin · Project source · File source
def densenet_cifar10_model(logits=False, input_range_type=1, pre_filter=lambda x: x,
                           input_tensor=None):
    """Build a DenseNet-40 (growth rate 12) Keras model for CIFAR-10.

    The architecture hyper-parameters are fixed below: depth 40, 3 dense
    blocks, growth rate 12, 16 initial filters, no dropout, and weight decay
    1e-4. The layers themselves are created by the module-level
    ``__create_dense_net`` helper.

    Args:
        logits: If truthy, the classifier head has no activation, so the
            model outputs raw logits. Otherwise a softmax is applied.
        input_range_type: Only ``1`` is supported. The meaning of the value
            is defined by the caller's convention (presumably the expected
            pixel range; confirm against callers).
        pre_filter: Accepted but NOT applied by this function. It is
            presumably kept for signature parity with sibling model
            builders; confirm before relying on it.
        input_tensor: Optional tensor to use as the model input. If it is not
            already a Keras tensor, it is wrapped with ``Input(tensor=...)``.
            Defaults to ``None``, which creates a fresh ``Input`` of shape
            32x32x3 (channel axis placed per ``K.image_data_format()``).

    Returns:
        A Keras ``Model`` named ``'densenet'`` with a 10-way output.

    Raises:
        ValueError: If ``input_range_type`` is not ``1``.
    """
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if input_range_type != 1:
        raise ValueError(
            "densenet_cifar10_model only supports input_range_type=1, got %r"
            % (input_range_type,))

    nb_classes = 10

    img_rows, img_cols = 32, 32
    img_channels = 3

    # Query the same data-format setting that _obtain_input_shape uses below,
    # so the two can never disagree. image_dim_ordering() is the deprecated
    # Keras 1 equivalent ('th' == channels_first).
    if K.image_data_format() == "channels_first":
        img_dim = (img_channels, img_rows, img_cols)
    else:
        img_dim = (img_rows, img_cols, img_channels)

    depth = 40
    nb_dense_block = 3
    growth_rate = 12
    nb_filter = 16
    dropout_rate = 0.0  # dropout disabled (upstream note: use 0.0 with data augmentation)
    include_top = True

    # No activation means the model emits raw logits.
    activation = None if logits else "softmax"

    # Determine proper input shape
    input_shape = _obtain_input_shape(img_dim,
                                      default_size=32,
                                      min_size=8,
                                      data_format=K.image_data_format(),
                                      include_top=include_top)

    if input_tensor is None:
        img_input = Input(shape=input_shape)
    elif not K.is_keras_tensor(input_tensor):
        img_input = Input(tensor=input_tensor, shape=input_shape)
    else:
        img_input = input_tensor

    # Positional arguments follow the upstream titu1994 helper. Presumably:
    # (nb_classes, img_input, include_top, depth, nb_dense_block, growth_rate,
    #  nb_filter, nb_layers_per_block=-1, bottleneck=False, reduction=0.0,
    #  dropout_rate, weight_decay=1e-4, activation). Confirm against the helper.
    x = __create_dense_net(nb_classes, img_input, True, depth, nb_dense_block,
                           growth_rate, nb_filter, -1, False, 0.0,
                           dropout_rate, 1E-4, activation)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input
    # Create model.
    model = Model(inputs, x, name='densenet')
    return model


# Source: https://github.com/titu1994/DenseNet
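Comments
Table of contents


Questions


Interview experiences


Articles

WeChat
Official account

Scan the QR code to follow the official account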