def feature_extractor(FLAGS, suffix=""):
    """Build a Keras backbone for use as a feature extractor.

    The architecture named by ``FLAGS.model`` is instantiated from
    ``keras.applications``. Its last layer is then removed with
    ``remove_last_layer``, and it is renamed to ``FLAGS.model + suffix``.
    If requested, ``add_regularizer`` is applied. Both helpers are
    referenced by name here and are defined elsewhere (not visible in this
    snippet).

    Args:
        FLAGS: Configuration object. This function reads three attributes:
            ``model`` must be one of "vgg16", "vgg19" or "resnet50".
            ``weights`` is passed to the Keras constructor as ``weights=``;
            the string "random" is mapped to ``None``.
            ``regularizer`` must be "l2" or "none".
        suffix: String appended to ``FLAGS.model`` to form the model's name.

    Returns:
        The constructed model, after ``remove_last_layer`` and (optionally)
        ``add_regularizer`` have been applied.

    Raises:
        NotImplementedError: If ``FLAGS.model`` or ``FLAGS.regularizer`` is
            not a supported value. Both are validated before any model is
            built, so a bad configuration fails without constructing the
            network or fetching weights.
    """
    import importlib

    # Model key -> (module path, class name). Modules are imported lazily so
    # only the requested architecture is loaded.
    backbones = {
        "vgg16": ("keras.applications.vgg16", "VGG16"),
        "vgg19": ("keras.applications.vgg19", "VGG19"),
        "resnet50": ("keras.applications.resnet50", "ResNet50"),
    }
    supported_regularizers = ("l2", "none")

    # Validate everything up front. Same precedence as before: an unknown
    # model is reported ahead of an unknown regularizer.
    if FLAGS.model not in backbones:
        raise NotImplementedError("Unsupported model: %r" % (FLAGS.model,))
    if FLAGS.regularizer not in supported_regularizers:
        raise NotImplementedError(
            "Unsupported regularizer: %r" % (FLAGS.regularizer,))

    # "random" means no pretrained weights; any other value is forwarded as-is.
    weights = FLAGS.weights if FLAGS.weights != "random" else None

    module_name, class_name = backbones[FLAGS.model]
    constructor = getattr(importlib.import_module(module_name), class_name)
    model = constructor(weights=weights)
    remove_last_layer(model)

    new_name = FLAGS.model + suffix
    try:
        model.name = new_name
    except AttributeError:
        # Some Keras versions expose ``name`` as a read-only property. Fall
        # back to its private backing attribute there.
        # NOTE(review): relies on a Keras implementation detail; confirm
        # against the Keras version in use.
        model._name = new_name

    if FLAGS.regularizer == "l2":
        add_regularizer(model)
    return model
评论列表
文章目录