def train(self, epochs=3):
    """Train a new classification head on top of a frozen InceptionV3 base.

    Calls ``self.prepare_generators()`` first if either data generator is
    still ``None``.  The network is InceptionV3 (ImageNet weights, no top)
    followed by global average pooling, a 1024-unit ReLU layer and a softmax
    layer with one unit per entry in ``training_generator.class_indices``.
    Every layer of the base model is frozen, so only the new head is trained.
    The compiled model is stored on ``self.classifier``.

    Args:
        epochs: Number of training epochs.  Defaults to 3, the value that
            used to be hard-coded.
    """
    # Local import: the file's import block is not visible from this method.
    from collections import Counter

    def _steps(generator, sample_count):
        # Keras 2 counts batches, not samples, per epoch.  Round up so the
        # final partial batch is still consumed.
        # NOTE(review): assumes the generator exposes ``batch_size`` as the
        # Keras image iterators do; falls back to one sample per step.
        batch_size = getattr(generator, "batch_size", None) or 1
        return (sample_count + batch_size - 1) // batch_size

    def _balanced_class_weight(generator):
        # Keras only understands a {class_index: weight} dict here; the
        # string "auto" that was passed before is not a supported value.
        # Weights follow the usual "balanced" heuristic:
        #     n_samples / (n_classes * samples_in_class)
        # NOTE(review): relies on the iterator's ``classes`` label array;
        # returns None (no weighting) if it is missing or not integer-like.
        labels = getattr(generator, "classes", None)
        if labels is None:
            return None
        try:
            counts = Counter(int(label) for label in labels)
        except (TypeError, ValueError):
            return None
        if not counts:
            return None
        total = float(sum(counts.values()))
        return {
            cls: total / (len(counts) * count)
            for cls, count in counts.items()
        }

    if self.training_generator is None or self.validation_generator is None:
        self.prepare_generators()

    base_model = InceptionV3(
        weights="imagenet",
        include_top=False,
        input_shape=self.input_shape,
    )

    # Freeze the pretrained convolutional base before compiling so that only
    # the freshly added head receives gradient updates.
    for layer in base_model.layers:
        layer.trainable = False

    output = base_model.output
    output = GlobalAveragePooling2D()(output)
    output = Dense(1024, activation="relu")(output)
    class_count = len(self.training_generator.class_indices)
    predictions = Dense(class_count, activation="softmax")(output)

    self.classifier = Model(inputs=base_model.input, outputs=predictions)
    self.classifier.compile(
        optimizer="rmsprop",
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    # Keras 2 argument names (steps_per_epoch / epochs / validation_steps)
    # replace the Keras 1 ones (samples_per_epoch / nb_epoch / nb_val_samples).
    self.classifier.fit_generator(
        self.training_generator,
        steps_per_epoch=_steps(
            self.training_generator, self.training_sample_count
        ),
        epochs=epochs,
        validation_data=self.validation_generator,
        validation_steps=_steps(
            self.validation_generator, self.validation_sample_count
        ),
        class_weight=_balanced_class_weight(self.training_generator),
    )
cnn_inception_v3_context_classifier.py 文件源码
python
阅读 15
收藏 0
点赞 0
评论 0
评论列表
文章目录