def save_params(self, path, include_global=True, include_local=False):
    """Save network parameters to the given ``path``.

    Parameters
    ----------
    path : str
        Filepath of parameter output file
    include_global : bool, optional
        If True (default True), save global network variables (those in
        ``tf.global_variables()``, which includes the trainable parameters)
    include_local : bool, optional
        If True (default False), save local (non-trainable) network
        variables (those in ``tf.local_variables()``)

    Returns
    -------
    str
        The path prefix the checkpoint was actually written to, as
        returned by ``tf.train.Saver.save``.

    Raises
    ------
    SimulationError
        If the simulation has already been closed.
    ValueError
        If the selected variable collections contain no variables.
    """
    if self.closed:
        raise SimulationError("Simulation has been closed, cannot save "
                              "parameters")

    with self.tensor_graph.graph.as_default():
        # Gather variables from the requested collections.  The collection
        # getters must run inside this graph's context.  Named
        # ``save_vars`` so the ``vars`` builtin is not shadowed.
        save_vars = []
        if include_global:
            save_vars.extend(tf.global_variables())
        if include_local:
            save_vars.extend(tf.local_variables())

        if not save_vars:
            # tf.train.Saver would raise an opaque "No variables to save"
            # here; fail early with a message naming the cause.  Kept as a
            # ValueError so existing ``except ValueError`` handlers still work.
            raise ValueError(
                "No variables to save (include_global=%s, include_local=%s)"
                % (include_global, include_local))

        # Saver.save returns the prefix it actually wrote to; keep it so the
        # log message and the return value reflect the real location.
        path = tf.train.Saver(save_vars).save(self.sess, path)

    logger.info("Model parameters saved to %s", path)
    return path
评论列表
文章目录