def _ensure_log_dir(log_dir):
    """Create ``log_dir`` (and any missing parents) if it does not exist yet.

    The confirmation message is printed only after the directory was
    actually created, so a failed creation never reports success.
    """
    if not os.path.exists(log_dir):
        # exist_ok guards against another process creating it in between.
        os.makedirs(log_dir, exist_ok=True)
        print('Successfully made a directory: {}'.format(log_dir))


def get_callbacks(log_dir=None, valid=(), tensorboard=True, eary_stopping=True):
    """Get callbacks.

    Args:
        log_dir (str): destination for TensorBoard logs and model-weight
            checkpoints. Created (including missing parents) when needed.
            A falsy value disables both the TensorBoard and checkpoint
            callbacks.
        valid (tuple): validation data, unpacked positionally into
            ``F1score(*valid)``. When empty, no F1score callback is added.
        tensorboard (bool): whether to add a TensorBoard callback
            (only takes effect when ``log_dir`` is given).
        eary_stopping (bool): whether to add early stopping on ``f1``.
            The misspelled name is kept for backward compatibility.

    Returns:
        list: the enabled callbacks, in the order TensorBoard, F1score,
        ModelCheckpoint, EarlyStopping.
    """
    callbacks = []

    if log_dir and tensorboard:
        _ensure_log_dir(log_dir)
        callbacks.append(TensorBoard(log_dir))

    if valid:
        # NOTE(review): presumably F1score writes its score to the epoch logs
        # under 'f1', which the checkpoint and early-stopping callbacks read.
        callbacks.append(F1score(*valid))

    if log_dir:
        _ensure_log_dir(log_dir)
        # The '{f1:...}' placeholder is filled from the epoch logs, so an
        # 'f1' entry must be present there when a checkpoint is written.
        file_name = '_'.join(['model_weights', '{epoch:02d}', '{f1:2.2f}']) + '.h5'
        save_callback = ModelCheckpoint(os.path.join(log_dir, file_name),
                                        monitor='f1',
                                        mode='max',  # higher F1 is better
                                        save_weights_only=True)
        callbacks.append(save_callback)

    if eary_stopping:
        # Stop once F1 has not improved (increased) for 3 epochs.
        callbacks.append(EarlyStopping(monitor='f1', patience=3, mode='max'))

    return callbacks
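# NOTE(review): stray web-page navigation text ("comment list", "table of
# contents") was pasted here; it is not part of the program.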