def run(self, T, train_feed_dict_supplier=None, val_feed_dict_suppliers=None,
hyper_constraints_ops=None,
_debug_no_hyper_update=False): # TODO add session parameter
"""
:param _debug_no_hyper_update:
:param T: number of steps
:param train_feed_dict_supplier:
:param val_feed_dict_suppliers:
:param hyper_constraints_ops: (list of) either callable (no parameters) or tensorflow ops
:return:
"""
    # idea: if steps == T then do the full reverse (or forward) mode, otherwise do trho and rtho;
    # after all, the main difference is that in the full version the method `initialize()`
    # is called after the gradient has been computed.
    self.hyper_gradients.run_all(T, train_feed_dict_supplier=train_feed_dict_supplier,
                                 val_feed_dict_suppliers=val_feed_dict_suppliers,
                                 hyper_batch_step=self.hyper_batch_step.eval())
    if not _debug_no_hyper_update:
        # apply the update ops of every hyper-optimizer in the current session
        for hod in self.hyper_optimizers:
            tf.get_default_session().run(hod.assign_ops)
        # enforce the (optional) constraints on the hyperparameters
        if hyper_constraints_ops:
            for op in as_list(hyper_constraints_ops):
                op() if callable(op) else op.eval()
        self.hyper_batch_step.increase.eval()
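
    # Minimal usage sketch (an assumption, not part of the library: `hyper_opt`,
    # `ell1_reg`, `val_error` and the feed suppliers below are hypothetical
    # placeholders, and the real constructor/supplier signatures may differ):
    #
    #     with tf.Session().as_default():
    #         tf.global_variables_initializer().run()
    #         hyper_opt.initialize()
    #         for _ in range(100):  # 100 hyper-batches
    #             hyper_opt.run(T=200,  # 200 inner optimization steps per hyper-batch
    #                           train_feed_dict_supplier=lambda step: {x: x_train, y: y_train},
    #                           val_feed_dict_suppliers={val_error: lambda step: {x: x_val, y: y_val}},
    #                           # keep the regularization hyperparameter non-negative
    #                           hyper_constraints_ops=[ell1_reg.assign(tf.maximum(ell1_reg, 0.))])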