hyper_gradients.py source file

python

Project: RFHO  Author: lucfra
def run(self, T, train_feed_dict_supplier=None, val_feed_dict_suppliers=None,
            hyper_constraints_ops=None,
            _debug_no_hyper_update=False):  # TODO add session parameter
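        """
        Computes the hyper-gradients over T steps and then, unless disabled, updates the hyperparameters.

        :param T: number of steps
        :param train_feed_dict_supplier: supplier of feed dicts for training, forwarded to
                                         `hyper_gradients.run_all`
        :param val_feed_dict_suppliers: supplier(s) of feed dicts for validation, forwarded to
                                        `hyper_gradients.run_all`
        :param hyper_constraints_ops: (list of) either callables (no parameters) or tensorflow ops,
                                      executed after the hyperparameter update
        :param _debug_no_hyper_update: if True, skip the hyperparameter update, the constraint ops
                                       and the increase of `hyper_batch_step`
        :return: None
        """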
        # Idea: if steps == T, do the full reverse or forward computation; otherwise do trho and rtho.
        # The main difference is that, with the full version, the method `initialize()` is called
        # after the gradient has been computed.

        self.hyper_gradients.run_all(T, train_feed_dict_supplier=train_feed_dict_supplier,
                                     val_feed_dict_suppliers=val_feed_dict_suppliers,
                                     hyper_batch_step=self.hyper_batch_step.eval())
        if not _debug_no_hyper_update:
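            # Apply the update ops of every hyper-optimizer.
            for hod in self.hyper_optimizers:
                tf.get_default_session().run(hod.assign_ops)
            # Enforce the constraints: call callables, evaluate tensorflow ops.
            if hyper_constraints_ops:
                for op in as_list(hyper_constraints_ops):
                    if callable(op):
                        op()
                    else:
                        op.eval()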

            self.hyper_batch_step.increase.eval()
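
The handling of `hyper_constraints_ops` in the method above (a single item or a list, each either a no-argument callable or an object with `.eval()`) can be sketched without TensorFlow. This is a minimal illustration, not RFHO's code: `FakeOp` and `apply_constraints` are hypothetical names, and `as_list` here is an assumed stand-in for the helper of the same name used in the method.

```python
def as_list(obj):
    """Assumed stand-in for the `as_list` helper: wrap a single item in a list."""
    return list(obj) if isinstance(obj, (list, tuple)) else [obj]


class FakeOp:
    """Hypothetical object that mimics only the `.eval()` interface of a tensorflow op."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def eval(self):
        # Record that this "op" was evaluated.
        self.log.append(self.name)


def apply_constraints(hyper_constraints_ops):
    """Call each constraint if it is callable, otherwise evaluate it."""
    if not hyper_constraints_ops:
        return
    for op in as_list(hyper_constraints_ops):
        if callable(op):
            op()
        else:
            op.eval()


log = []
# A mixed list: one callable and one op-like object.
apply_constraints([lambda: log.append("callable"), FakeOp("op", log)])
# A single op-like object, not wrapped in a list.
apply_constraints(FakeOp("single", log))
print(log)
```

Accepting both forms lets a caller pass either a plain Python function (for example, one that clips a value) or a prebuilt graph op without changing the call site.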