dl_vgg.py source code

Project: jamespy_py3  Author: jskDr
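```python
# Imports assumed from the standalone Keras package (not shown in the original snippet).
from keras.applications.vgg16 import VGG16
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model


def build_model(model):
    """Attach a small classifier head to an ImageNet-pretrained VGG16 base.

    `model` must provide `nb_classes` and `in_shape`; the built tensors
    and monitoring sub-models are stored back on it as attributes.
    """
    nb_classes = model.nb_classes
    input_shape = model.in_shape

    # Convolutional base only: the fully connected top layers are dropped.
    base_model = VGG16(weights='imagenet', include_top=False,
                       input_shape=input_shape)

    x = base_model.input
    h = base_model.output
    z_cl = h  # Saving for cl output monitoring (the VGG16 feature map).

    h = GlobalAveragePooling2D()(h)
    h = Dense(10, activation='relu')(h)
    z_fl = h  # Saving for fl output monitoring (the 10-unit dense activations).

    y = Dense(nb_classes, activation='softmax', name='preds')(h)

    # Freeze the pretrained VGG16 layers so only the new head is trained.
    for layer in base_model.layers:
        layer.trainable = False

    # Sub-models that expose intermediate activations for monitoring.
    model.cl_part = Model(x, z_cl)
    model.fl_part = Model(x, z_fl)

    # Input and output tensors of the full network.
    model.x = x
    model.y = y
```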
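To make the shapes in the classification head concrete, the NumPy sketch below mimics the forward pass that `build_model` wires on top of the frozen VGG16 base: global average pooling over the spatial axes, a 10-unit ReLU layer, then a softmax over the classes. It assumes a 224x224x3 input (so the base yields a 7x7x512 feature map in channels_last layout) and a hypothetical `nb_classes = 4`; the weights are random stand-ins, not trained values. It also counts the parameters that stay trainable once the base layers are frozen.

```python
import numpy as np

# Hypothetical sizes: VGG16 without its top, fed a 224x224x3 image,
# produces a 7x7x512 feature map (channels_last).
batch, height, width, channels = 2, 7, 7, 512
hidden_units = 10  # matches Dense(10, activation='relu') in build_model
nb_classes = 4     # hypothetical number of target classes

rng = np.random.default_rng(0)
# Random stand-in for base_model.output (the tensor saved as z_cl).
z_cl = rng.standard_normal((batch, height, width, channels))

# GlobalAveragePooling2D (channels_last): average over height and width.
pooled = z_cl.mean(axis=(1, 2))  # shape (batch, 512)

# Dense(10, relu) with random stand-in weights.
w1 = rng.standard_normal((channels, hidden_units)) * 0.01
b1 = np.zeros(hidden_units)
z_fl = np.maximum(pooled @ w1 + b1, 0.0)  # shape (batch, 10)

# Dense(nb_classes, softmax) with random stand-in weights.
w2 = rng.standard_normal((hidden_units, nb_classes)) * 0.1
b2 = np.zeros(nb_classes)
logits = z_fl @ w2 + b2
logits -= logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
y = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# With the VGG16 base frozen, only the two Dense layers carry trainable weights.
trainable_params = w1.size + b1.size + w2.size + b2.size

print(pooled.shape, z_fl.shape, y.shape)  # (2, 512) (2, 10) (2, 4)
print(trainable_params)                   # 512*10 + 10 + 10*4 + 4 = 5174
```

Because the convolutional base is frozen, the trainable part is tiny compared with VGG16's roughly 14.7 million convolutional weights, which is what makes this head cheap to fit on a small dataset.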