def get_restore_op(self):
    """
    Get variable restoring ngraph op from TF model checkpoint.

    Restores all TF global variables of ``self._graph`` from
    ``self._checkpoint_path`` in a temporary TF session. It then reads the
    restored values and builds one ``ng.assign`` per variable that copies
    each value into the matching ngraph op.

    Returns:
        A `ng.doall` op that restores the stored weights in TF model
        checkpoint

    Raises:
        ValueError: if ``self._graph`` is None (no meta graph imported yet)
            or ``self._checkpoint_path`` is None (no checkpoint path given).
    """
    if self._graph is None:
        raise ValueError("self._graph is None, import meta_graph first.")
    if self._checkpoint_path is None:
        # The two literals are concatenated, so the first needs a trailing
        # space to keep the words apart.
        raise ValueError("self._checkpoint_path is None, please specify "
                         "checkpoint_path while importing meta_graph.")
    with self._graph.as_default():
        tf_variables = tf.global_variables()
        # presumably returns one ngraph op per TF variable, in the same
        # order (they are paired positionally below) -- TODO confirm
        ng_variables = self.get_op_handle(tf_variables)
        with tf.Session(graph=self._graph) as sess:
            # A relative checkpoint path is resolved against the current
            # working directory; os.path.join keeps an absolute path as is.
            checkpoint_path = os.path.join(os.getcwd(),
                                           self._checkpoint_path)
            # NOTE(review): self.saver is assumed to be the Saver created
            # when the meta graph was imported -- confirm.
            self.saver.restore(sess, checkpoint_path)
            # Fetch every restored value in a single run call rather than
            # one session round-trip per variable.
            values = sess.run(tf_variables)
        ng_restore_ops = [ng.assign(ng_variable, val)
                          for ng_variable, val in zip(ng_variables, values)]
    return ng.doall(ng_restore_ops)
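# 评论列表 ("comment list") -- web-page navigation residue, not code
# 文章目录 ("table of contents") -- web-page navigation residue, not code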