cnn_inception_v3_context_classifier.py source code

python

Project: SerpentAI  Author: SerpentAI
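import math

from keras.applications.inception_v3 import InceptionV3
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model


def train(self):
    # Method of the context classifier class: input_shape, the two generators,
    # the sample counts and prepare_generators() are defined elsewhere in it.
    if self.training_generator is None or self.validation_generator is None:
        self.prepare_generators()

    # InceptionV3 convolutional base with ImageNet weights, without the
    # original 1000-class top layers
    base_model = InceptionV3(
        weights="imagenet",
        include_top=False,
        input_shape=self.input_shape
    )

    # New head: pool the feature maps to a vector, add a hidden layer,
    # then one softmax unit per context class found by the generator
    output = base_model.output
    output = GlobalAveragePooling2D()(output)
    output = Dense(1024, activation="relu")(output)

    predictions = Dense(len(self.training_generator.class_indices), activation="softmax")(output)
    self.classifier = Model(inputs=base_model.input, outputs=predictions)

    # Freeze the pre-trained base so only the new head is trained
    for layer in base_model.layers:
        layer.trainable = False

    self.classifier.compile(
        optimizer="rmsprop",
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )

    # Current Keras trains from generators with fit() and counts in batches:
    # samples_per_epoch, nb_epoch and nb_val_samples became steps_per_epoch,
    # epochs and validation_steps. class_weight must be a dict
    # {class_index: weight}, so the original class_weight="auto" is dropped
    # and all classes are weighted equally.
    self.classifier.fit(
        self.training_generator,
        steps_per_epoch=math.ceil(self.training_sample_count / self.training_generator.batch_size),
        epochs=3,
        validation_data=self.validation_generator,
        validation_steps=math.ceil(self.validation_sample_count / self.validation_generator.batch_size)
    )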
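Newer Keras versions measure an epoch in batches rather than samples (`steps_per_epoch` instead of `samples_per_epoch`) and expect `class_weight` to be a dict mapping class indices to weights, not the string `"auto"` passed in `train()`. The sketch below shows both calculations in plain Python. The helper names `steps_for` and `balanced_class_weights` are hypothetical. The weighting formula is the common "balanced" heuristic (the one scikit-learn uses for `class_weight="balanced"`), swapped in here as an assumption; it is not taken from SerpentAI.

```python
import math
from collections import Counter


def steps_for(sample_count, batch_size):
    """Number of batches needed to cover sample_count samples once.

    The last batch may be partial, hence the ceiling division.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return math.ceil(sample_count / batch_size)


def balanced_class_weights(labels):
    """Hypothetical helper: map each class index to
    n_samples / (n_classes * count_of_that_class).

    Rare classes get weights above 1 and common classes below 1, so each
    class contributes roughly the same total weight to the loss.
    """
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {int(cls): n_samples / (n_classes * count) for cls, count in counts.items()}


if __name__ == "__main__":
    # Hypothetical numbers: 1000 training images in batches of 32
    print(steps_for(1000, 32))  # 32 (31 full batches + 1 partial)

    # Hypothetical labels for 3 context classes, skewed towards class 0
    labels = [0] * 6 + [1] * 3 + [2] * 1
    print(balanced_class_weights(labels))
```

If `prepare_generators()` builds the generators with `flow_from_directory` (its body is not shown here), each generator is a Keras `DirectoryIterator`, which exposes the per-sample class indices as `.classes`. In that case a call like `balanced_class_weights(self.training_generator.classes)` could stand in for `"auto"`. Whether this helps depends on how unevenly the context classes are represented.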