keras_cnn_finetune.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:geom_rcnn 作者: asbroad 项目源码 文件源码
def make_model(self):
        """Build and compile a transfer-learning classifier on a pre-trained base.

        Loads an ImageNet-pretrained convolutional base selected by
        ``self.model_architecture`` ('vgg16', 'resnet' or 'inception') without
        its original classification head (``include_top=False``). It then
        appends global average pooling, a 1024-unit ReLU dense layer and a
        softmax output layer with one unit per entry in ``self.categories``.
        All base-model layers are frozen, so only the new head is trained.
        The model is compiled with ``self.optimizer``, categorical
        cross-entropy loss and an accuracy metric.

        Side effects:
            Sets ``self.base_model`` and ``self.model``. If
            ``self.model_architecture`` is not recognised, prints an error,
            requests a ROS node shutdown and returns without building a model.
            When ``self.verbose`` is true, reports the compile time through
            ``self.print_time``.
        """
        # Create the base pre-trained model. The Keras application modules are
        # imported inside each branch, so only the selected one is imported.
        if self.model_architecture == 'vgg16':
            from keras.applications.vgg16 import VGG16
            self.base_model = VGG16(weights='imagenet', include_top=False)
        elif self.model_architecture == 'resnet':
            from keras.applications.resnet50 import ResNet50
            self.base_model = ResNet50(weights='imagenet', include_top=False)
        elif self.model_architecture == 'inception':
            from keras.applications.inception_v3 import InceptionV3
            self.base_model = InceptionV3(weights='imagenet', include_top=False)
        else:
            print('Model architecture parameter unknown. Options are: vgg16, resnet, and inception')
            rospy.signal_shutdown("Model architecture unknown.")
            # signal_shutdown only requests a shutdown; it does not stop this
            # function. Without returning here, the code below would use a
            # missing (or stale) self.base_model.
            return

        # Replace the original classification head with a new one:
        # global average pooling followed by a 1024-unit ReLU dense layer.
        x = self.base_model.output
        x = GlobalAveragePooling2D()(x)
        x = Dense(1024, activation='relu')(x)
        # Add the output layer: one softmax unit per category.
        predictions = Dense(len(self.categories), activation='softmax')(x)

        # Create the new model from the pre-trained network and the new final
        # layers. The arguments are passed positionally because the keyword
        # names differ between Keras 1 (input/output) and Keras 2
        # (inputs/outputs). The input size comes from self.base_model.input; to
        # change it, build the base model above with input_shape/input_tensor.
        self.model = Model(self.base_model.input, predictions)

        # Freeze every pre-trained layer so only the new head is updated.
        # Changes to `trainable` take effect at compile time, so this must
        # happen before compile() below.
        for layer in self.base_model.layers:
            layer.trainable = False

        if self.verbose:
            print('compiling model ... ')
            start_time = time.time()

        # In fine-tuning, these hyper-parameters (notably self.optimizer) can
        # matter a lot. Monitor how well the model is learning and tune them.
        self.model.compile(optimizer=self.optimizer,
                           loss='categorical_crossentropy',
                           metrics=['accuracy'])

        if self.verbose:
            end_time = time.time()
            self.print_time(start_time, end_time, 'compiling model')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号