hyper_gradients.py source code

python

Project: RFHO  Author: lucfra
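import tensorflow as tf  # TensorFlow 1.x API (provides tf.get_default_session)


def run_all(self, T, train_feed_dict_supplier=None, val_feed_dict_suppliers=None, hyper_batch_step=None,
            forward_su=None, after_forward_su=None):
    """
    Helper method: run T forward (training) steps, then compute the hyper-gradients.

    :param T: number of forward steps to run
    :param train_feed_dict_supplier: supplier of the feed dict for each training step (passed to `step_forward`)
    :param val_feed_dict_suppliers: supplier(s) of the validation feed dict(s) (passed to `hyper_gradients`)
    :param hyper_batch_step: support for stochastic sampling of the validation set (passed to `hyper_gradients`)
    :param forward_su: summary utilities run at each forward step (passed to `step_forward` as `summary_utils`)
    :param after_forward_su: summary utilities run once after the forward pass, with the default session and T
    :return: the result of `self.hyper_gradients(val_feed_dict_suppliers, hyper_batch_step)`
    """

    # self.initialize()
    for _ in range(T):
        self.step_forward(train_feed_dict_supplier=train_feed_dict_supplier, summary_utils=forward_su)

    if after_forward_su:
        after_forward_su.run(tf.get_default_session(), T)

    return self.hyper_gradients(val_feed_dict_suppliers, hyper_batch_step)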
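
The method above follows a two-phase pattern: run T forward steps of the inner training dynamics, then compute the gradient of a validation objective with respect to the hyperparameters. The sketch below illustrates one way (reverse mode) to do the second phase on a toy problem, in plain Python instead of TensorFlow. It is not RFHO's implementation: the class `ToyReverseHG`, the 1-D quadratic training and validation losses, and the choice of the learning rate `eta` as the only hyperparameter are all hypothetical assumptions made for illustration.

```python
from typing import List


class ToyReverseHG:
    """Hypothetical toy (not RFHO's API): reverse-mode hypergradient of the
    validation loss w.r.t. the learning rate eta of T gradient-descent steps.

    Training loss:   L(w) = 0.5 * a * (w - b) ** 2
    Validation loss: E(w) = 0.5 * (w - c) ** 2
    """

    def __init__(self, eta: float, w0: float, a: float, b: float, c: float):
        self.eta, self.a, self.b, self.c = eta, a, b, c
        # Store w_0, ..., w_T so the reverse pass can revisit every state
        self.trajectory: List[float] = [w0]

    def step_forward(self) -> None:
        # One gradient-descent step: w_{t+1} = w_t - eta * dL/dw(w_t)
        w = self.trajectory[-1]
        self.trajectory.append(w - self.eta * self.a * (w - self.b))

    def hyper_gradients(self) -> float:
        # Reverse pass: p holds dE/dw_{t+1}, propagated back through the trajectory
        p = self.trajectory[-1] - self.c  # dE/dw_T
        grad_eta = 0.0
        for w_t in reversed(self.trajectory[:-1]):  # w_{T-1}, ..., w_0
            grad_eta += p * (-self.a * (w_t - self.b))  # p * (partial w_{t+1} / partial eta)
            p *= 1.0 - self.eta * self.a  # chain through dw_{t+1}/dw_t
        return grad_eta

    def run_all(self, T: int) -> float:
        # Same shape as run_all above: T forward steps, then the hypergradient
        for _ in range(T):
            self.step_forward()
        return self.hyper_gradients()


if __name__ == "__main__":
    hg = ToyReverseHG(eta=0.1, w0=0.0, a=2.0, b=1.0, c=0.5)
    print(hg.run_all(T=10))
```

Storing the whole trajectory is what makes the reverse pass possible; its memory cost grows linearly with T, which is the usual trade-off of reverse-mode hypergradients compared with forward-mode approaches.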