dl_resnet50.py source code

python

Project: jamespy_py3  Author: jskDr
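```python
# Imports are not shown in the original snippet; these tf.keras equivalents are assumed.
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.models import Model


class ResNet50Classifier:  # Hypothetical class name; the original name is not shown.
    def __init__(self, input_shape, nb_classes, weights='imagenet'):
        # ResNet50 convolutional backbone without its ImageNet classification head.
        base_model = ResNet50(weights=weights, include_top=False,
                              input_shape=input_shape)

        x = base_model.input
        h = base_model.output
        z_cl = h  # Convolutional feature maps, kept for output monitoring.

        # New classification head: global average pooling -> Dense(128) -> Dropout.
        h = GlobalAveragePooling2D()(h)
        h = Dense(128, activation='relu')(h)
        h = Dropout(0.5)(h)
        z_fl = h  # Fully connected features, kept for output monitoring.

        y = Dense(nb_classes, activation='softmax', name='preds')(h)

        # Freeze the backbone so only the new head is trained.
        # trainable must be set before compile() for the change to take effect.
        for layer in base_model.layers:
            layer.trainable = False

        model = Model(x, y)
        # Note: in tf.keras the 'adadelta' string uses a default learning rate
        # of 0.001, far lower than the 1.0 default of standalone Keras 2.
        model.compile(loss='categorical_crossentropy',
                      optimizer='adadelta', metrics=['accuracy'])

        self.model = model
        # Auxiliary models that share the same layers, for inspecting
        # intermediate outputs without retraining anything.
        self.cl_part = Model(x, z_cl)
        self.fl_part = Model(x, z_fl)
```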
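To make the shape flow of the added head concrete, here is a minimal NumPy sketch of what it computes at inference time. It is an illustration under stated assumptions, not the author's code: random weights stand in for trained ones, a random 7×7×2048 array stands in for `base_model.output` (the map ResNet50 with `include_top=False` produces for a 224×224 input), and the helper names `softmax`, `head_forward`, and `categorical_crossentropy` are hypothetical. `Dropout(0.5)` is left out because Keras applies it only during training. The loss follows the textbook categorical cross-entropy formula rather than Keras's exact implementation, and it shows why the labels passed to `model.fit` must be one-hot encoded.

```python
import numpy as np


def softmax(logits):
    # Subtract the row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def head_forward(feature_maps, w1, b1, w2, b2):
    """Hypothetical inference-time pass: GAP -> Dense(ReLU) -> Dense(softmax).

    Dropout is omitted because it is inactive at inference time.
    """
    pooled = feature_maps.mean(axis=(1, 2))     # (N, H, W, C) -> (N, C)
    hidden = np.maximum(pooled @ w1 + b1, 0.0)  # (N, C) -> (N, 128), ReLU
    return softmax(hidden @ w2 + b2)            # (N, 128) -> (N, nb_classes)


def categorical_crossentropy(one_hot, probs, eps=1e-7):
    # Batch mean of -sum(y_true * log(y_pred)); clipping avoids log(0).
    probs = np.clip(probs, eps, 1.0 - eps)
    return float(-(one_hot * np.log(probs)).sum(axis=1).mean())


rng = np.random.default_rng(0)
n, h, w, c, hidden_units, nb_classes = 2, 7, 7, 2048, 128, 4
feature_maps = rng.standard_normal((n, h, w, c))  # stand-in for ResNet50 output
w1 = rng.standard_normal((c, hidden_units)) * 0.01
b1 = np.zeros(hidden_units)
w2 = rng.standard_normal((hidden_units, nb_classes)) * 0.01
b2 = np.zeros(nb_classes)

probs = head_forward(feature_maps, w1, b1, w2, b2)
labels = np.eye(nb_classes)[[1, 3]]  # integer class labels -> one-hot rows
loss = categorical_crossentropy(labels, probs)
print(probs.shape, loss)
```

In the Keras model, `self.cl_part.predict(images)` returns the `(N, H, W, 2048)` feature maps that this sketch feeds in, and `self.fl_part.predict(images)` returns the 128-dimensional `hidden` values, since dropout is an identity at prediction time. With near-zero weights the predicted distribution is close to uniform, so the loss sits near log(nb_classes).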