def train(Xtrain, ytrain, Xtrain_norm, ytrain_norm, Xvalidate, yvalidate, space):
    """Build, compile and fit a ``vgg_variant`` model with metric-based early stopping.

    :param Xtrain: training inputs handed to ``data_generator``
    :param ytrain: training targets handed to ``data_generator``
    :param Xtrain_norm: training inputs scored with ``get_metrics`` after every epoch
    :param ytrain_norm: training targets scored with ``get_metrics`` after every epoch
    :param Xvalidate: validation inputs (scored every epoch and passed to Keras)
    :param yvalidate: validation targets (scored every epoch and passed to Keras)
    :param space: mapping of hyper-parameters; the keys read here are
        ``learning_rate``, ``batch_size``, ``weighted_sampling``, ``augment``,
        ``sampling_factor``, ``sampling_intercept`` and ``samples_per_epoch``
        (the whole mapping is also forwarded to ``vgg_variant`` and
        ``data_generator``)
    :return: tuple ``(best_model, rvalues)`` where ``best_model`` is the model
        carrying the weights of the best monitored epoch (``None`` if no epoch
        ever improved on the initial best of 0 by ``delta``) and ``rvalues`` is
        ``{'train': [...], 'validate': [...]}`` with one score per epoch
    """
    import sys
    from keras.optimizers import RMSprop
    from keras.callbacks import Callback

    class CorrelationEarlyStopping(Callback):
        """Stop training once the monitored ``get_metrics`` score stops improving."""

        def __init__(self, monitor='validate', patience=0, delta=.001):
            """
            :param monitor: 'validate' or 'train'
            :param patience: how many epochs to wait
            :param delta: by how much the monitored value has to be greater than the last maximum
            """
            # Let the Keras base class set up its own state before adding ours.
            super(CorrelationEarlyStopping, self).__init__()
            self.rvalues = {'train': [], 'validate': []}
            self.monitor = monitor  # validate, train
            self.patience = patience
            self.delta = delta
            self.wait = 0
            self.best = 0
            self.num_epochs = 0
            self.best_model = None
            # Snapshot of the weights at the best epoch. ``best_model`` is only
            # a reference to the live model, which keeps training afterwards.
            self.best_weights = None

        def on_epoch_end(self, epoch, logs=None):
            """Score the model on both data sets and update the stopping state."""
            score = get_metrics(self.model, x=Xtrain_norm, y=ytrain_norm)
            self.rvalues['train'].append(score)
            score = get_metrics(self.model, x=Xvalidate, y=yvalidate)
            self.rvalues['validate'].append(score)
            print('\n\tTrain r2: {}\n\tValidate r2: {}\n'.format(self.rvalues['train'][-1], self.rvalues['validate'][-1]))
            sys.stdout.flush()

            current = self.rvalues[self.monitor][-1]
            if current - self.delta >= self.best:
                # Improvement of at least ``delta``: remember this epoch.
                self.best = current
                self.wait = 0
                self.num_epochs = epoch
                self.best_model = self.model
                self.best_weights = self.model.get_weights()
            elif self.wait >= self.patience:
                # Patience exhausted: report the epoch count before the wait began.
                self.num_epochs = epoch - self.patience
                self.model.stop_training = True
            else:
                self.num_epochs = epoch
                self.wait += 1

        def on_train_end(self, logs=None):
            """Put the best epoch's weights back so ``best_model`` really is the best."""
            if self.best_weights is not None:
                self.model.set_weights(self.best_weights)

    model = vgg_variant(space)
    # ``learning_rate`` is stored as a negative base-10 exponent.
    lr = 10 ** (-space['learning_rate'])
    rmsprop = RMSprop(lr=lr, rho=0.9, epsilon=1e-08)
    model.compile(loss='mean_squared_error', optimizer=rmsprop)
    monitor = CorrelationEarlyStopping(monitor='validate', patience=6, delta=0.01)
    gen = data_generator(Xtrain, ytrain, batch_size=space['batch_size'], space=space,
                         weighted_sampling=space['weighted_sampling'], augment=space['augment'],
                         sampling_factor=space['sampling_factor'], sampling_intercept=space['sampling_intercept'])
    # NOTE(review): positional arguments presumably follow the Keras 1.x
    # signature (samples_per_epoch, nb_epoch, verbose, callbacks,
    # validation_data) -- confirm against the installed Keras version.
    model.fit_generator(gen, space['samples_per_epoch'], 50, 1, [monitor], (Xvalidate, yvalidate))
    return monitor.best_model, monitor.rvalues
# Example source code for the Python class Callback()
def keras(self):
    """
    Return the Keras callback bound to this object's experiment, or None.

    The callback class is imported, defined and instantiated only on the
    first call, so that Keras (and whatever it pulls in) is never imported
    for users who do not ask for it. The result is cached in
    ``self._callbacks`` under ``KERAS``: the callback instance on success,
    or ``False`` when the import raised ``ImportError`` so that later calls
    return ``None`` without retrying.
    """
    cached = self._callbacks.get(KERAS)

    # A stored False records that an earlier import attempt failed.
    if cached is False:
        return None

    # Already built on a previous call.
    if cached:
        return cached

    # First call: try to import Keras and fall back gracefully.
    try:
        from keras.callbacks import Callback as KerasCallback

        class _KerasCallback(KerasCallback):
            """Keras callback that reports validation metrics to an injected Experiment.

            # TODO: Decide if we want to handle the additional callbacks:
            # 1) on_epoch_begin
            # 2) on_batch_begin
            # 3) on_batch_end
            # 4) on_train_begin
            # 5) on_train_end
            """

            def __init__(self, exp):
                super(_KerasCallback, self).__init__()
                self._exp = exp

            def on_epoch_end(self, epoch, logs=None):
                logs = logs or {}
                # Forward only the metrics that are actually present.
                for metric_name in ("val_acc", "val_loss"):
                    metric_value = logs.get(metric_name)
                    if metric_value is not None:
                        self._exp.metric(metric_name, metric_value)

        callback = _KerasCallback(self._exp)
        self._callbacks[KERAS] = callback
        return callback
    except ImportError:
        # Mark Keras as unimportable for future calls.
        self._callbacks[KERAS] = False
        return None
# Version of Experiment with a different name for use internally, should not be used directly by consumers
def __init__(self, model, log_dir, histogram_freq=0, image_freq=0, audio_freq=0, write_graph=False):
    """Register TensorFlow summaries for ``model`` and open a summary writer.

    :param model: model whose ``layers``, ``input`` and ``output`` are summarised
    :param log_dir: directory handed to ``tf.train.SummaryWriter``
    :param histogram_freq: when non-zero, histogram summaries are registered
        for each layer's ``W``, ``b`` and ``output`` attributes (where present)
    :param image_freq: when non-zero, image summaries of the model input and
        output are registered
    :param audio_freq: when non-zero, audio summaries of the model input and
        output are registered
    :param write_graph: when true, the session graph is passed to the writer
    :raises Exception: if ``K._BACKEND`` is not ``'tensorflow'``
    """
    import re

    # NOTE(review): super(Callback, self) starts the MRO lookup *after*
    # Callback, so if this class derives from Callback its __init__ is
    # skipped. The enclosing class is not visible here -- confirm the intent.
    super(Callback, self).__init__()
    if K._BACKEND != 'tensorflow':
        raise Exception('TensorBoardBatch callback only works '
                        'with the TensorFlow backend.')

    import tensorflow as tf
    self.tf = tf
    import keras.backend.tensorflow_backend as KTF
    self.KTF = KTF

    self.log_dir = log_dir
    self.histogram_freq = histogram_freq
    self.image_freq = image_freq
    self.audio_freq = audio_freq
    self.histograms = None
    self.write_graph = write_graph
    self.iter = 0
    self.scalars = []
    self.images = []
    self.audios = []

    self.model = model
    self.sess = KTF.get_session()

    if self.histogram_freq != 0:
        for layer in self.model.layers:
            # Fall back to the layer object itself when it carries no name.
            layer_name = layer.name if hasattr(layer, 'name') else layer

            if hasattr(layer, 'W'):
                name = '{}_W'.format(layer_name)
                tf.histogram_summary(name, layer.W, collections=['histograms'])
            if hasattr(layer, 'b'):
                name = '{}_b'.format(layer_name)
                tf.histogram_summary(name, layer.b, collections=['histograms'])
            if hasattr(layer, 'output'):
                name = '{}_out'.format(layer_name)
                tf.histogram_summary(name, layer.output, collections=['histograms'])

    if self.image_freq != 0:
        tf.image_summary('input', self.model.input, max_images=2, collections=['images'])
        tf.image_summary('output', self.model.output, max_images=2, collections=['images'])

    if self.audio_freq != 0:
        tf.audio_summary('input', self.model.input, max_outputs=1, collections=['audios'])
        tf.audio_summary('output', self.model.output, max_outputs=1, collections=['audios'])

    if self.write_graph:
        # Compare (major, minor) numerically. A plain string comparison
        # orders '0.10.0' before '0.8.0' and would pick the legacy branch
        # on TensorFlow 0.10-0.12.
        tf_version = tuple(int(part) for part in re.findall(r'\d+', self.tf.__version__)[:2])
        if tf_version >= (0, 8):
            self.writer = self.tf.train.SummaryWriter(self.log_dir, self.sess.graph)
        else:
            self.writer = self.tf.train.SummaryWriter(self.log_dir, self.sess.graph_def)
    else:
        self.writer = self.tf.train.SummaryWriter(self.log_dir)