retina_net.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:qtim_ROP 作者: QTIM-Lab 项目源码 文件源码
def _configure_network(self, build=True):
    """Instantiate ``self.model`` from ``self.config['network']`` and optionally compile it.

    The architecture is selected by substring match on the lower-cased
    ``network['type']`` value:

    * contains ``'vgg'``    -> Keras ``VGG16`` including its classification top.
    * contains ``'resnet'`` -> Keras ``ResNet50`` base (no top) followed by a
      Flatten / Dense(1024, relu) / Dropout(0.5) / Dense(3, softmax) head.
    * anything else         -> a GoogLeNet-style model: built with
      ``googlenet.create_googlenet`` when ``network['arch']`` is absent,
      otherwise deserialized from the JSON file at ``network['arch']``.
      Types containing ``'googlenet'`` register the ``PoolHelper`` and ``LRN``
      custom layers for that deserialization; if ``network['weights']`` is
      set, weights are then loaded by layer name.

    Config keys read from ``self.config['network']``: ``'type'`` (required),
    ``'weights'``, ``'arch'``, ``'no_classes'`` (default 3) and
    ``'no_features'`` (default 1024).

    Args:
        build: when true, also compile the model with the optimizer described
            by ``self.config['optimizer']`` (keys ``'type'``, ``'loss'`` and an
            optional ``'params'`` dict of optimizer keyword arguments), using
            ``'accuracy'`` as the metric.
    """
    network = self.config['network']
    type_ = network['type'].lower()
    weights = network.get('weights', None)
    # Suffix appended to the "Instantiating ..." log message only.
    pretrain_note = (" with pre-trained weights '{}'".format(weights)
                     if weights else " without pre-training")

    if 'vgg' in type_:
        from keras.applications.vgg16 import VGG16
        logging.info("Instantiating VGG model" + pretrain_note)
        # With include_top=True, keras.applications only accepts VGG16's
        # default input size (224x224); 227x227 is rejected with a ValueError.
        # NOTE(review): the (3, H, W) shape assumes channels-first image
        # ordering — confirm against the Keras backend configuration.
        self.model = VGG16(weights=weights, input_shape=(3, 224, 224), include_top=True)

    elif 'resnet' in type_:
        from keras.applications.resnet50 import ResNet50
        logging.info("Instantiating ResNet model" + pretrain_note)

        input_layer = Input(shape=(3, 224, 224))
        base_model = ResNet50(weights=weights, include_top=False, input_tensor=input_layer)

        # Replace the ImageNet classifier with a small 3-way softmax head.
        x = base_model.output
        x = Flatten()(x)
        x = Dense(1024, activation='relu')(x)
        x = Dropout(0.5)(x)
        predictions = Dense(3, activation='softmax')(x)

        self.model = Model(input=base_model.input, output=predictions)
        # NOTE: the base ResNet layers are not frozen here; to freeze them,
        # set ``layer.trainable = False`` on each of ``base_model.layers``.

    else:
        if 'googlenet' in type_:
            custom_objects = {"PoolHelper": PoolHelper, "LRN": LRN}
            mod_str = 'GoogLeNet'
        else:
            custom_objects = {}
            mod_str = 'custom'

        logging.info("Instantiating {} model".format(mod_str) + pretrain_note)
        arch = network.get('arch', None)

        if arch is None:
            # Imported lazily so a JSON-defined architecture does not require
            # the googlenet module to be importable.
            from googlenet import create_googlenet
            self.model = create_googlenet(network.get('no_classes', 3),
                                          network.get('no_features', 1024))
        else:
            # Close the architecture file deterministically.
            with open(arch) as arch_file:
                self.model = model_from_json(arch_file.read(), custom_objects=custom_objects)

        if weights:
            logging.info("Loading weights '{}'".format(weights))
            self.model.load_weights(weights, by_name=True)

    # Configure optimizer
    if build:
        opt_options = self.config['optimizer']
        name, loss = opt_options['type'], opt_options['loss']
        params = opt_options.get('params', {})
        optimizer = OPTIMIZERS[name](**params)
        self.model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号