def init_callbacks(self, for_worker=False):
    """Build the keras CallbackList that training will drive.

    A fresh ``History`` callback is created, stored on
    ``self.model.history`` and placed after every user callback.

    Args:
        for_worker (bool): When True, ``EarlyStopping`` and
            ``ModelCheckpoint`` instances are dropped from
            ``self.callbacks_list``. The attribute is rebound, so the
            removal persists. These callbacks only make sense when
            validation is enabled.

    Side effects:
        Sets ``self.model.history``, ``self.callbacks`` and
        ``self.callback_model``. Also resets
        ``self.callback_model.stop_training`` to False.
    """
    import keras.callbacks as cbks

    # Callback types that depend on validation results.
    validation_only = (cbks.EarlyStopping, cbks.ModelCheckpoint)
    if for_worker:
        self.callbacks_list = [cb for cb in self.callbacks_list
                               if not isinstance(cb, validation_only)]

    history = cbks.History()
    self.model.history = history
    self.callbacks = cbks.CallbackList(self.callbacks_list + [history])

    # The callbacks may be bound to a model other than self.model
    # (used by Sequential models). Fall back to self.model when no
    # truthy callback_model is present.
    delegate = getattr(self.model, 'callback_model', None)
    self.callback_model = delegate if delegate else self.model
    self.callbacks.set_model(self.callback_model)
    self.callback_model.stop_training = False
# Example source code for the Python class CallbackList()
def fit(self, dataloader, nb_iter=None, nb_epoch=None, iter_per_epoch=None,
        callbacks=None, verbose=0):
    """Train the underlying Keras model batch by batch, driving callbacks.

    Args:
        dataloader (StandardDataLoader): Manages the loading of data to
            the model. The result of ``get_training_batch()`` is unpacked
            as positional arguments to ``keras_model.train_on_batch``.
        nb_iter (int): The number of iterations to train the model.
        nb_epoch (int): The number of epochs to train the model.
        iter_per_epoch (int): Defines the number of iterations per epoch.
        callbacks (list): List of Keras callbacks to run during training.
            Defaults to no callbacks.
        verbose (int): If > 0, draw a per-epoch progress bar on stdout.

    ``on_epoch_end`` receives ``logs={'losses': ...}``, holding the value
    returned by the last ``train_on_batch`` call of that epoch. It is only
    fired when an epoch's final iteration completes, so a trailing
    partial epoch, or one cut short by ``stop_training``, gets none.

    Raises:
        KeyboardInterrupt: Re-raised after printing an abort message.
            ``on_train_end`` is not called in that case.
    """
    import time

    nb_iter, iter_per_epoch = self._get_iterations(
        nb_iter, nb_epoch, iter_per_epoch)
    # None sentinel instead of a shared mutable [] default argument.
    callback_list = CallbackList([] if callbacks is None else callbacks)
    callback_list._set_model(self)
    callback_list.on_train_begin()
    try:
        epoch = 0
        # Callbacks may set this to True to end training early.
        self.stop_training = False
        # range (not xrange) runs on both Python 2 and Python 3.
        for i in range(nb_iter):
            # Begin epoch
            if i % iter_per_epoch == 0:
                callback_list.on_epoch_begin(epoch)
            # Execution
            callback_list.on_batch_begin(i)
            if verbose > 0:
                # NOTE(review): short pause kept from the original, presumably
                # to let the terminal redraw; confirm it is still wanted.
                time.sleep(0.001)
                j = i % iter_per_epoch
                perc = int(100 * (j + 1) / iter_per_epoch)
                # Floor division keeps the repeat count an int on
                # Python 3 / under true division (perc/2 raised TypeError).
                prog = '=' * (perc // 2)
                sys.stdout.write("[{:50s}] {:3d}%\r".format(prog, perc))
                sys.stdout.flush()
            losses = self.keras_model.train_on_batch(
                *dataloader.get_training_batch())
            callback_list.on_batch_end(i)
            # End epoch
            if (i + 1) % iter_per_epoch == 0:
                callback_list.on_epoch_end(epoch, logs={'losses': losses})
                epoch += 1
            if self.stop_training:
                break
    except KeyboardInterrupt:
        print("\n[BayesNet] Abort: KeyboardInterrupt")
        raise
    callback_list.on_train_end()