def _configure_network(self, build=True):
    """Instantiate ``self.model`` from ``self.config['network']`` and optionally compile it.

    The architecture is chosen by substring match on the lower-cased
    ``network['type']``:

    * contains ``'vgg'``    -> Keras ``VGG16`` including its own top layers.
    * contains ``'resnet'`` -> Keras ``ResNet50`` base (no top) followed by
      Flatten -> Dense(1024, relu) -> Dropout(0.5) -> Dense(3, softmax).
    * anything else         -> a GoogLeNet/custom model, either built by
      ``googlenet.create_googlenet`` (when ``network['arch']`` is absent) or
      deserialised from the JSON file at ``network['arch']``. If
      ``network['weights']`` is set, the weights are then loaded by layer name.

    ``network['weights']`` (optional) is passed straight to the Keras
    constructors for VGG/ResNet and to ``load_weights`` for the other models.

    Args:
        build: when true, also compile ``self.model`` using
            ``self.config['optimizer']``. That entry provides ``type`` (a key into
            ``OPTIMIZERS``), ``loss``, and optional ``params`` (keyword arguments
            for the optimizer). The ``'accuracy'`` metric is always used.

    Raises:
        KeyError: if ``network['type']`` is missing, or, when ``build`` is
            true, if ``optimizer['type']``/``optimizer['loss']`` are missing or
            the type is not in ``OPTIMIZERS``.
    """
    network = self.config['network']
    type_ = network['type'].lower()
    weights = network.get('weights')
    if weights:
        fine_tuning = " with pre-trained weights '{}'".format(weights)
    else:
        fine_tuning = " without pre-training"

    if 'vgg' in type_:
        from keras.applications.vgg16 import VGG16
        logging.info("Instantiating VGG model%s", fine_tuning)
        # NOTE(review): include_top=True keeps VGG16's own classifier (presumably
        # the default 1000 ImageNet classes, unlike the 3-class heads below), and
        # some Keras versions reject a non-224x224 input_shape in that mode --
        # confirm that (3, 227, 227) and the top layers are intended.
        self.model = VGG16(weights=weights, input_shape=(3, 227, 227), include_top=True)
    elif 'resnet' in type_:
        from keras.applications.resnet50 import ResNet50
        logging.info("Instantiating ResNet model%s", fine_tuning)
        # (3, 224, 224) looks like channels-first ordering -- TODO confirm backend setting.
        input_layer = Input(shape=(3, 224, 224))
        base_model = ResNet50(weights=weights, include_top=False, input_tensor=input_layer)
        # New 3-class classification head on top of the convolutional base.
        x = Flatten()(base_model.output)
        x = Dense(1024, activation='relu')(x)
        x = Dropout(0.5)(x)
        predictions = Dense(3, activation='softmax')(x)
        self.model = Model(input=base_model.input, output=predictions)
    else:
        if 'googlenet' in type_:
            # Custom layer classes must be registered to deserialise GoogLeNet JSON.
            custom_objects = {"PoolHelper": PoolHelper, "LRN": LRN}
            mod_str = 'GoogLeNet'
        else:
            custom_objects = {}
            mod_str = 'custom'
        from googlenet import create_googlenet
        logging.info("Instantiating %s model%s", mod_str, fine_tuning)
        arch = network.get('arch')
        if arch is None:
            self.model = create_googlenet(network.get('no_classes', 3),
                                          network.get('no_features', 1024))
        else:
            # Read the architecture JSON and close the file promptly.
            with open(arch) as arch_file:
                arch_json = arch_file.read()
            self.model = model_from_json(arch_json, custom_objects=custom_objects)
        if weights:
            # Use logging, not print, consistent with the rest of this method.
            logging.info("Loading weights '%s'", weights)
            # by_name=True: weights are matched to layers by layer name.
            self.model.load_weights(weights, by_name=True)

    # Configure optimizer
    if build:
        opt_options = self.config['optimizer']
        name, loss = opt_options['type'], opt_options['loss']
        # 'params' is optional; an empty dict means the optimizer's own defaults.
        params = opt_options.get('params', {})
        # OPTIMIZERS is presumably a module-level name -> optimizer-class map.
        optimizer = OPTIMIZERS[name](**params)
        self.model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])