def create_callbacks(self, callback: Callable[[], None], tensor_board_log_directory: Path, net_directory: Path,
                     callback_step: int = 1, save_step: int = 1) -> List[Callback]:
    """Build the list of Keras callbacks used while training ``self.predictive_net``.

    Args:
        callback: Zero-argument function invoked from ``on_epoch_end`` whenever
            ``epoch % callback_step == 0``.
        tensor_board_log_directory: Log directory for the TensorBoard callback;
            only used when the active Keras backend is ``'tensorflow'``.
        net_directory: Directory that receives the weights saved via
            ``self.predictive_net.save_weights``; created if missing before each save.
        callback_step: Epoch interval at which ``callback`` is invoked. Must be >= 1.
        save_step: Epoch interval at which weights are saved (epoch 0 is skipped).
            Must be >= 1.

    Returns:
        A list holding a ``TensorBoard`` callback (TensorFlow backend only)
        followed by the custom epoch-end callback.

    Raises:
        ValueError: If ``callback_step`` or ``save_step`` is smaller than 1.
    """
    # Fail fast: a zero step would otherwise only surface as a ZeroDivisionError
    # in on_epoch_end, after the first epoch of training has already run.
    if callback_step < 1 or save_step < 1:
        raise ValueError('callback_step and save_step must be >= 1, got {} and {}'.format(
            callback_step, save_step))

    class CustomCallback(Callback):
        # `self_callback` is the Keras callback instance itself; `self` inside the
        # body refers to the enclosing object that owns `predictive_net`.
        def on_epoch_end(self_callback, epoch, logs=None):
            if epoch % callback_step == 0:
                callback()

            # `epoch > 0` deliberately skips saving at epoch 0 (presumably the
            # first epoch, given Keras' 0-based epoch indices — confirm).
            if epoch % save_step == 0 and epoch > 0:
                # Idempotent creation: the directory normally already exists from
                # an earlier save (or a previous run), which must not be an error.
                net_directory.mkdir(parents=True, exist_ok=True)
                self.predictive_net.save_weights(str(net_directory / self.model_file_name(epoch)))

    # TensorBoard logging is only attached when running on the TensorFlow backend.
    tensorboard_if_running_tensorflow = [TensorBoard(
        log_dir=str(tensor_board_log_directory), write_images=True)] if backend.backend() == 'tensorflow' else []
    return tensorboard_if_running_tensorflow + [CustomCallback()]
评论列表
文章目录