def train(self, train_batches, valid_batches, samples_per_epoch, nb_epoch, nb_val_samples, extra_callbacks=None):
    """Train the model.

    Automatically adds the following Keras callbacks:
        - ModelCheckpoint: writes a checkpoint into ``self.checkpoint_dir``
          on every save (``save_best_only=False``), with the epoch number
          and that epoch's ``val_loss`` in the file name.
        - EarlyStopping: monitors ``val_loss`` with ``patience=2``.
        - TensorBoard: logs to ``self.tensorboard_dir`` without writing
          the graph (``write_graph=False``).

    Args:
        train_batches (Iterable[Batch]): an iterable of training Batches
        valid_batches (Iterable[Batch]): an iterable of validation Batches
        samples_per_epoch (int): number of training samples to process per
            epoch; forwarded positionally to ``fit_generator``
        nb_epoch (int): max number of epochs to train for
        nb_val_samples (int): number of samples for validation
        extra_callbacks (list): a list of additional Keras callbacks to run,
            appended after the built-in ones. ``None`` or empty means none.

    Returns:
        Whatever ``self.keras_model.fit_generator`` returns (in Keras this is
        a ``History`` object holding per-epoch metrics). Previously this was
        discarded; callers that ignore the return value are unaffected.
    """
    # The template is filled in with the epoch number and that epoch's
    # val_loss; val_loss exists because validation_data is passed below.
    checkpoint_path = join(self.checkpoint_dir, 'weights.{epoch:02d}-{val_loss:.2f}.hdf5')
    checkpointer = ModelCheckpoint(checkpoint_path, verbose=1, save_best_only=False)
    early_stopper = EarlyStopping(monitor='val_loss', patience=2, verbose=1)
    tboard = TensorBoard(self.tensorboard_dir, write_graph=False)

    callbacks = [checkpointer, early_stopper, tboard]
    if extra_callbacks:
        callbacks.extend(extra_callbacks)

    # Suffixed with _gen so the locals don't shadow the method name ``train``.
    train_gen = self._vectorized_batches(train_batches)
    valid_gen = self._vectorized_batches(valid_batches)

    return self.keras_model.fit_generator(
        train_gen, samples_per_epoch, nb_epoch,
        callbacks=callbacks,
        validation_data=valid_gen,
        nb_val_samples=nb_val_samples,
    )
评论列表
文章目录