def __init__(self, optimizer, hyper_dict, method, hyper_grad_kwargs=None,
hyper_optimizer_class=AdamOptimizer, **optimizers_kwargs):
"""
Interface instance of gradient-based hyperparameter optimization methods.
:param optimizer: parameter optimization dynamics (obtained from `Optimizer.create` methods)
:param hyper_dict: dictionary mapping each validation error to the list of
hyperparameters to be optimized w.r.t. it
:param method: method with which to compute hyper-gradients: Forward-HO
(`ForwardHG`) or Reverse-HO (`ReverseHG`)
:param hyper_grad_kwargs: dictionary of keyword arguments for `HyperGradient` classes (usually None)
:param hyper_optimizer_class: Optimizer class used to optimize the hyperparameters
(default `AdamOptimizer`); pass None to compute hyper-gradients without updates
:param optimizers_kwargs: keyword arguments for hyperparameter optimizers (like hyper-learning rate)
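
Example (a minimal usage sketch; `MomentumOptimizer`, `w`, `training_error`,
`validation_error`, `lr` and `mu` are assumed names, built elsewhere through
this module's `Optimizer.create` interface)::

    optimizer = MomentumOptimizer.create(w, lr=lr, mu=mu, loss=training_error)
    hyper_opt = HyperOptimizer(optimizer, {validation_error: [lr, mu]},
                               method=ReverseHG)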
"""
assert method in [ReverseHG, ForwardHG]
assert hyper_optimizer_class is None or issubclass(hyper_optimizer_class, Optimizer)
assert isinstance(hyper_dict, dict)
assert isinstance(optimizer, Optimizer)
if not hyper_grad_kwargs:
    hyper_grad_kwargs = {}
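# step counter for the outer (hyper) optimization loop; the report op below
# checks whether hyper_iteration_step still needs to be initialized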
self.hyper_iteration_step = GlobalStep(name='hyper_iteration_step')
self._report_hyper_it_init = tf.report_uninitialized_variables([self.hyper_iteration_step.var])
self.hyper_batch_step = GlobalStep(name='batch_step')
# automatically links the optimizer's global step, if it has one (e.g. Adam),
# to the HyperGradient global step
hyper_grad_kwargs['global_step'] = hyper_grad_kwargs.get(
'global_step', optimizer.global_step if hasattr(optimizer, 'global_step') else GlobalStep())
# automatically links the hyper-optimizer's global step (Adam keeps one) to hyper_batch_step
if hyper_optimizer_class == AdamOptimizer:
optimizers_kwargs['global_step'] = self.hyper_batch_step
optimizers_kwargs.setdefault('eps', 1.e-14)
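# build the hyper-gradient computation (ForwardHG or ReverseHG) over the
# optimization dynamics and the validation objectives in hyper_dict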
self.hyper_gradients = method(optimizer, hyper_dict, **hyper_grad_kwargs)
if hyper_optimizer_class:
# noinspection PyTypeChecker
self.hyper_optimizers = create_hyperparameter_optimizers(
self.hyper_gradients, optimizer_class=hyper_optimizer_class, **optimizers_kwargs)
else:
self.hyper_optimizers = None