def train(self):
    """Build, compile and fit the model, checkpointing on validation loss.

    Saves the best weights (by ``val_loss``) under
    ``MODEL_WEIGHT_BASEPATH/<BASE_NAME>/`` and pickles the Keras training
    history dict under ``MODEL_HISTORY_BASEPATH/<BASE_NAME>/``.
    Uses the legacy Keras 1.x ``fit_generator`` API, consistent with the
    rest of this module.
    """
    model = self.model_module.build_model(IRMAS_N_CLASSES)

    # Stop when val_loss has not improved for EARLY_STOPPING_EPOCH epochs.
    early_stopping = EarlyStopping(monitor='val_loss', patience=EARLY_STOPPING_EPOCH)

    # NOTE: the {epoch:02d}/{val_loss:.3f}/{val_fbeta_score:.3f} placeholders
    # are deliberately left unformatted here -- Keras' ModelCheckpoint
    # substitutes them at save time. Only {key} is filled in now; the
    # .format() call binds to the "-{key}.hdf5" literal alone.
    save_clb = ModelCheckpoint(
        "{weights_basepath}/{model_path}/".format(
            weights_basepath=MODEL_WEIGHT_BASEPATH,
            model_path=self.model_module.BASE_NAME) +
        "epoch.{epoch:02d}-val_loss.{val_loss:.3f}-fbeta.{val_fbeta_score:.3f}" +
        "-{key}.hdf5".format(key=self.model_module.MODEL_KEY),
        monitor='val_loss',
        save_best_only=True)

    # Halve the learning rate every SGD_LR_REDUCE epochs.
    lrs = LearningRateScheduler(
        lambda epoch_n: self.init_lr / (2 ** (epoch_n // SGD_LR_REDUCE)))

    model.summary()
    model.compile(optimizer=self.optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy', fbeta_score])

    history = model.fit_generator(self._batch_generator(self.X_train, self.y_train),
                                  samples_per_epoch=self.model_module.SAMPLES_PER_EPOCH,
                                  nb_epoch=MAX_EPOCH_NUM,
                                  verbose=2,
                                  callbacks=[save_clb, early_stopping, lrs],
                                  validation_data=self._batch_generator(self.X_val, self.y_val),
                                  nb_val_samples=self.model_module.SAMPLES_PER_VALIDATION,
                                  class_weight=None,
                                  nb_worker=1)

    # BUG FIX: the original opened the history file in text mode ('w') and
    # never closed it. pickle requires a binary handle ('wb') in Python 3;
    # the context manager guarantees the file is flushed and closed.
    history_path = '{history_basepath}/{model_path}/history_{model_key}.pkl'.format(
        history_basepath=MODEL_HISTORY_BASEPATH,
        model_path=self.model_module.BASE_NAME,
        model_key=self.model_module.MODEL_KEY)
    with open(history_path, 'wb') as history_file:
        pickle.dump(history.history, history_file)