def load_params(self, path, include_global=True, include_local=False):
    """Restore network parameters from the file at ``path``.

    The variables to restore are collected from the simulator's graph and
    handed to a ``tf.train.Saver``, which restores them into ``self.sess``.

    Parameters
    ----------
    path : str
        Filepath of parameter input file
    include_global : bool, optional
        If True (default True), restore global (trainable) network
        variables
    include_local : bool, optional
        If True (default False), restore local (non-trainable) network
        variables

    Raises
    ------
    SimulationError
        If the simulation has already been closed
    """
    if self.closed:
        raise SimulationError("Simulation has been closed, cannot load "
                              "parameters")

    # Variable collections are looked up relative to the default graph, so
    # both the lookup and the Saver construction happen inside this context.
    with self.tensor_graph.graph.as_default():
        restore_targets = []
        if include_global:
            restore_targets += tf.global_variables()
        if include_local:
            restore_targets += tf.local_variables()

        saver = tf.train.Saver(restore_targets)
        saver.restore(self.sess, path)

    logger.info("Model parameters loaded from %s", path)
评论列表
文章目录