def train(self, blueprint, device,
          save_best_model=False, model_filename=None):
    """Build and train the model described by ``blueprint``.

    The model is built with ``self.model_builder.build(blueprint, device)``.
    It is trained with ``fit_generator`` on ``self.batch_iterator`` and
    validated on ``self.test_batch_iterator``. The epoch count and the
    callback list come from ``self._get_stopping_parameters(blueprint)``.

    Any exception raised while building or training is logged at DEBUG
    level and swallowed. The Keras session is then cleared on a best-effort
    basis and ``(None, None, 0)`` is returned.

    :param blueprint: model description passed to the model builder. When
        ``save_best_model`` is true, its ``training.metric.metric`` is
        forwarded to the save callback.
    :param device: device passed to both the model builder and
        ``setup_tf_session``.
    :param save_best_model: when true, append the callback returned by
        ``self._get_model_save_callback`` and, after training, return the
        model reloaded from ``model_filename`` instead of the in-memory
        final model.
    :param model_filename: file handed to the save callback and reloaded
        after training. NOTE(review): presumably required when
        ``save_best_model`` is true; confirm how a ``None`` is handled.
    :return: ``(model, history, elapsed_seconds)`` on success, where
        ``elapsed_seconds`` covers only the ``fit_generator`` call;
        ``(None, None, 0)`` on any failure.
    """
    try:
        model = self.model_builder.build(blueprint, device)
        setup_tf_session(device)
        nb_epoch, callbacks = self._get_stopping_parameters(blueprint)
        if save_best_model:
            callbacks.append(self._get_model_save_callback(
                model_filename,
                blueprint.training.metric.metric))
        start = time()
        history = model.fit_generator(
            self.batch_iterator,
            self.batch_iterator.samples_per_epoch,
            nb_epoch,
            callbacks=callbacks,
            validation_data=self.test_batch_iterator,
            nb_val_samples=self.test_batch_iterator.sample_count)
        # Stop the clock here so the optional reload below (disk I/O) is not
        # counted as training time.
        elapsed = time() - start
        if save_best_model:
            # Swap the in-memory model for the one the save callback wrote
            # to disk; drop our reference to the old one before loading.
            del model
            model = load_keras_model(model_filename)
        return model, history, elapsed
    except Exception as ex:
        logging.debug("Model training failed: %s", ex, exc_info=True)
        try:
            # Best-effort cleanup of Keras/TF state left by the failed run.
            from keras import backend
            backend.clear_session()
        except Exception as cleanup_ex:
            # Log the cleanup error itself, not the training error again;
            # catching Exception (not bare except) lets Ctrl-C propagate.
            logging.debug("Failed to clear Keras session: %s", cleanup_ex,
                          exc_info=True)
        return None, None, 0
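# NOTE(review): two lines of web-page navigation text ("comment list",
# "table of contents") were captured here along with this snippet; they
# are not part of the program.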