def keras(self):
    """
    Returns an object that implements the Keras Callback interface.

    The Keras callback is built lazily, on first call, so that any possible
    import issues cannot affect users who don't use it, and so that
    Keras/TensorFlow and all of their accompanying baggage are not imported
    unnecessarily when they happen to be installed but are not being used.

    The outcome is cached in ``self._callbacks[KERAS]``:
    - after a successful import, it holds the callback instance;
    - after a failed import, it holds ``False``.
    The import is therefore attempted at most once.

    Returns:
        The cached ``_KerasCallback`` instance wrapping ``self._exp``, or
        ``None`` if ``keras.callbacks`` cannot be imported.
    """
    cb = self._callbacks.get(KERAS)
    # A previous call already found Keras unimportable.
    if cb is False:
        return None
    # A previous call already built the callback.
    if cb is not None:
        return cb

    # Keep the try minimal: only the import itself is expected to fail.
    try:
        from keras.callbacks import Callback as KerasCallback
    except ImportError:
        # Mark Keras as unimportable so future calls skip the import attempt.
        self._callbacks[KERAS] = False
        return None

    class _KerasCallback(KerasCallback):
        """_KerasCallback implements KerasCallback using an injected Experiment.

        Only ``on_epoch_end`` is implemented; it forwards validation metrics
        found in the epoch's ``logs`` to the Experiment.
        """

        # TODO: Decide if we want to handle the additional callbacks:
        # 1) on_epoch_begin
        # 2) on_batch_begin
        # 3) on_batch_end
        # 4) on_train_begin
        # 5) on_train_end

        # Keys forwarded to the Experiment under their own names.
        # Keras >= 2.3 reports metrics under the name the user compiled with.
        # So metrics=['accuracy'] yields "val_accuracy", while metrics=['acc']
        # (and older Keras) yields "val_acc". Both spellings are accepted.
        _METRIC_KEYS = ("val_acc", "val_accuracy", "val_loss")

        def __init__(self, exp):
            super(_KerasCallback, self).__init__()
            self._exp = exp

        def on_epoch_end(self, epoch, logs=None):
            """Record each known validation metric present in ``logs``."""
            logs = logs or {}
            for key in self._METRIC_KEYS:
                value = logs.get(key)
                if value is not None:
                    self._exp.metric(key, value)

    cb = _KerasCallback(self._exp)
    self._callbacks[KERAS] = cb
    return cb
# A version of Experiment under a different name, for internal use only; consumers should not use it directly.
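# NOTE(review): two lines of web-page navigation residue ("comment list",
# "table of contents") from text extraction were here; they are not source code.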