def train_model(self, batch_size=32, nb_epoch=50, load_data=False,
                old_weight_path='', checkpoint_dir='output'):
    """Build the base model, optionally resume from saved weights, and train it.

    During training a ``ModelCheckpoint`` callback saves the full model
    (``save_weights_only=False``) whenever ``val_loss`` reaches a new minimum
    (``save_best_only=True``, ``mode='min'``).

    Args:
        batch_size: Number of samples per gradient update.
        nb_epoch: Number of training epochs (forwarded to ``fit`` as ``epochs``).
        load_data: If true, the train/validation split comes from
            ``self.load_data()``; otherwise it is built by
            ``self.prepare_train_data()``. This flag shares its name with the
            ``self.load_data`` method. The two do not clash because the method
            is always accessed through ``self``.
        old_weight_path: Path of previously saved weights to continue from.
            An empty string (the default) or ``None`` means "start from the
            freshly built model".
        checkpoint_dir: Directory that receives the checkpoint files. It is
            created if it does not exist. Defaults to ``'output'``, the
            location that was previously hard-coded.

    Returns:
        The model built by ``self.baseModel()``, after ``fit`` has run on it.
    """
    import os  # function-scope import keeps this block self-contained

    print("start training model...")
    if load_data:
        train_data, train_labels, valid_data, valid_labels = self.load_data()
    else:
        train_data, train_labels, valid_data, valid_labels = self.prepare_train_data()

    model = self.baseModel()
    # Truthiness check skips both '' (the default) and None. Comparing only
    # against '' would pass None on to load_weights.
    if old_weight_path:
        print("load last epoch model to continue train")
        model.load_weights(old_weight_path)

    # The checkpoint callback writes files into this directory. Make sure it
    # exists so the first save does not fail on a missing parent folder.
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint = ModelCheckpoint(
        os.path.join(checkpoint_dir, "weights.{epoch:02d}-{val_loss:.2f}.hdf5"),
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
        save_weights_only=False,
        mode='min',
        # NOTE(review): `period` is deprecated in newer Keras releases in
        # favour of `save_freq`. It is kept here for compatibility with the
        # Keras version this file targets; confirm before upgrading Keras.
        period=2,
    )
    model.fit(train_data, train_labels, batch_size=batch_size,
              epochs=nb_epoch,
              validation_data=(valid_data, valid_labels),
              callbacks=[checkpoint, ProgbarLogger()])
    return model
# 评论列表 (Comments)
# 文章目录 (Table of contents)