inceptionv4.py source code

python

Project: dogsVScats  Author: prajwalkr
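from keras import backend as K
from keras.layers import Input, Dense, Flatten, AveragePooling2D
from keras.models import Model

# inception_v4_base, pop_layer and TF_WEIGHTS_PATH are defined elsewhere in this file.


def inception_v4():
    '''
    Creates the Inception v4 network, loads weights from TF_WEIGHTS_PATH,
    and replaces the 1001-way softmax head with a single sigmoid unit
    for binary classification.

    Takes no arguments.

    Returns:
        model: a Keras Model mapping a 299 x 299 RGB image to one
            sigmoid output. Layers up to and including 'merge_25'
            are frozen.
    '''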

    # Input Shape is 299 x 299 x 3 (tf) or 3 x 299 x 299 (th)

    if K.image_dim_ordering() == 'th':
        inputs = Input((3, 299, 299))
    else:
        inputs = Input((299, 299, 3))

    # Make inception base
    net = inception_v4_base(inputs)

    # Final pooling and prediction

    # 8 x 8 x 1536
    net = AveragePooling2D((8,8), border_mode='valid')(net)

    # 1 x 1 x 1536
    net = Flatten()(net)

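    # 1536
    # Original 1001-way softmax head, built before loading the weights.
    predictions = Dense(output_dim=1001, activation='softmax')(net)

    model = Model(inputs, predictions, name='inception_v4')

    # by_name=True loads weights only into layers whose names match the file.
    model.load_weights(TF_WEIGHTS_PATH, by_name=True)

    # pop_layer is defined elsewhere in this file; it presumably removes
    # the final layer, i.e. the 1001-way head.
    model = pop_layer(model)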
    # batchnormed = BatchNormalization(axis=3) ()
    # dense = Dense(128) (model.layers[-1].output)
    # batchnormed = BatchNormalization() (model.layers[-1].output)
    # relu = Activation('relu') (batchnormed)
    # dropout = Dropout(0.5) (relu)
    # New head: one sigmoid unit for binary classification.
    predictions = Dense(output_dim=1, activation='sigmoid')(model.layers[-1].output)

    model = Model(inputs, predictions, name='inception_v4')

    # Freeze every layer up to and including 'merge_25'; later layers stay trainable.
    for layer in model.layers:
        layer.trainable = False
        if layer.name == 'merge_25':
            break

    return model
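
The freezing loop at the end of inception_v4() can be tried out without Keras. The sketch below is a stand-in under stated assumptions, not the project's code: FakeLayer and freeze_through are hypothetical names, and every layer name except 'merge_25' is made up for illustration. It shows that the loop marks each layer up to and including the named one as non-trainable and leaves the rest untouched.

```python
class FakeLayer:
    """Hypothetical stand-in for a Keras layer: just a name and a trainable flag."""

    def __init__(self, name):
        self.name = name
        self.trainable = True


def freeze_through(layers, last_frozen_name):
    """Set trainable=False on each layer in order, stopping after last_frozen_name.

    Mirrors the loop in inception_v4(): the flag is cleared before the
    name check, so the named layer itself is frozen too.
    """
    for layer in layers:
        layer.trainable = False
        if layer.name == last_frozen_name:
            break
    return layers


# Illustrative names only; 'merge_25' is the one name taken from the source.
layers = [FakeLayer(n) for n in ("conv_1", "merge_24", "merge_25", "dense_out")]
freeze_through(layers, "merge_25")
trainable_flags = {layer.name: layer.trainable for layer in layers}
```

In this toy list only dense_out stays trainable. If the name never matches, the loop runs to the end and freezes every layer, so it may be worth checking that 'merge_25' actually exists in the model before relying on it.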