def iterate_training(model, dataset, initial_epoch):
    """Train ``model`` on ``dataset``, attaching checkpoint/logging callbacks.

    Calls ``model.fit_generator`` with a training batch generator and a
    dev-set batch generator for validation, and attaches these callbacks:

    * ``ModelCheckpoint`` writing to the file ``MODEL_CHECKPOINT_FILENAME``
      inside ``MODEL_CHECKPOINT_DIRECTORYNAME``, with ``save_best_only=True``.
    * ``TensorBoard`` with default settings.
    * ``CSVLogger`` writing to ``CSV_LOG_FILENAME``.
    * A ``LambdaCallback`` that calls ``show_samples`` at the end of every
      epoch on one batch of 1000 dev examples drawn once before training.

    Args:
        model: Model exposing ``fit_generator`` with the Keras 1.x keyword
            arguments used below (``samples_per_epoch``, ``nb_epoch``,
            ``nb_val_samples``, ``initial_epoch``).
        dataset: Object providing ``train_set_batch_generator(batch_size)``
            and ``dev_set_batch_generator(batch_size)``. Each returns an
            iterator whose items unpack into an ``(X, y)`` pair.
        initial_epoch: Epoch index to start training from; forwarded to
            ``fit_generator`` (presumably used to resume training — confirm
            with callers).

    Returns:
        Whatever ``model.fit_generator`` returns (in Keras, a ``History``
        object with per-epoch metrics). The original implementation
        discarded this value.
    """
    import os

    # os.path.join avoids a doubled separator when the directory name
    # already ends in one, and uses the platform's separator.
    checkpoint = ModelCheckpoint(
        os.path.join(MODEL_CHECKPOINT_DIRECTORYNAME, MODEL_CHECKPOINT_FILENAME),
        save_best_only=True,
    )
    tensorboard = TensorBoard()
    csv_logger = CSVLogger(CSV_LOG_FILENAME)

    # One dev batch is drawn up front and captured by the lambda below, so
    # show_samples sees the same examples every epoch (presumably so the
    # printed samples are comparable across epochs).
    X_dev_batch, y_dev_batch = next(dataset.dev_set_batch_generator(1000))
    show_samples_callback = LambdaCallback(
        on_epoch_end=lambda epoch, logs: show_samples(
            model, dataset, epoch, logs, X_dev_batch, y_dev_batch
        )
    )

    train_batch_generator = dataset.train_set_batch_generator(BATCH_SIZE)
    validation_batch_generator = dataset.dev_set_batch_generator(BATCH_SIZE)

    # NOTE(review): these are Keras 1.x keyword names. Keras 2 renamed them
    # to steps_per_epoch / epochs / validation_steps and counts batches, not
    # samples — confirm the installed Keras version before migrating.
    # NOTE(review): nb_val_samples reuses SAMPLES_PER_EPOCH, the training
    # sample count; confirm a separate validation count was not intended.
    return model.fit_generator(
        train_batch_generator,
        samples_per_epoch=SAMPLES_PER_EPOCH,
        nb_epoch=NUMBER_OF_EPOCHS,
        validation_data=validation_batch_generator,
        nb_val_samples=SAMPLES_PER_EPOCH,
        callbacks=[checkpoint, tensorboard, csv_logger, show_samples_callback],
        verbose=1,
        initial_epoch=initial_epoch,
    )
评论列表
文章目录