importer.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def get_restore_op(self):
        """
        Get variable restoring ngraph op from TF model checkpoint

        Restores the checkpoint at ``self._checkpoint_path`` into a new
        TensorFlow session on ``self._graph``. It then reads back the value
        of every TF global variable and builds one ``ng.assign`` per
        variable. Each assign copies that value into the matching op handle
        returned by ``self.get_op_handle``.

        Returns:
            A `ng.doall` op that restores the stored weights in TF model
            checkpoint

        Raises:
            ValueError: if no meta graph has been imported (``self._graph``
                is None) or no checkpoint path was supplied
                (``self._checkpoint_path`` is None).
        """
        if self._graph is None:
            raise ValueError("self._graph is None, import meta_graph first.")
        if self._checkpoint_path is None:
            # The two literals are concatenated implicitly, so the first one
            # needs a trailing space to keep the words apart.
            raise ValueError("self._checkpoint_path is None, please specify "
                             "checkpoint_path while importing meta_graph.")
        with self._graph.as_default():
            tf_variables = tf.global_variables()
            ng_variables = self.get_op_handle(tf_variables)
            with tf.Session() as sess:
                # os.path.join keeps an absolute checkpoint path unchanged and
                # resolves a relative one against the current directory.
                checkpoint_path = os.path.join(os.getcwd(),
                                               self._checkpoint_path)
                self.saver.restore(sess, checkpoint_path)
                # Fetch every variable value in one Session.run call instead
                # of issuing a separate run per variable.
                tf_values = sess.run(tf_variables) if tf_variables else []
            # zip pairs TF values with ngraph handles in order; like the
            # original, it stops at the shorter of the two sequences.
            ng_restore_ops = [
                ng.assign(ng_variable, val)
                for ng_variable, val in zip(ng_variables, tf_values)
            ]
            return ng.doall(ng_restore_ops)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号