def initialize_uninitialized_global_variables(sess):
    """
    Only initializes the variables of a TensorFlow session that were not
    already initialized.

    Each call creates one ``tf.is_variable_initialized`` op per global
    variable (and an initializer op when something needs initializing).

    :param sess: the TensorFlow session used to query and run initializers
    :return: None
    """
    # List all global variables
    global_vars = tf.global_variables()
    # Nothing to check or initialize in an empty graph; this also avoids
    # handing sess.run() an empty fetch list.
    if not global_vars:
        return
    # Find initialized status for all variables
    is_var_init = [tf.is_variable_initialized(var) for var in global_vars]
    is_initialized = sess.run(is_var_init)
    # List all variables that were not initialized previously
    not_initialized_vars = [var for (var, init) in
                            zip(global_vars, is_initialized) if not init]
    # Initialize all uninitialized variables found, if any
    if not_initialized_vars:
        sess.run(tf.variables_initializer(not_initialized_vars))
# Example source code using Python's is_variable_initialized()
def initialize_uninitialized_variables(sess):
    """
    Only initialize the weights that have not yet been initialized by other
    means, such as importing a metagraph and a checkpoint. It's useful when
    extending an existing model.

    :param sess: the TensorFlow session used to query and run initializers
    :return: None
    """
    all_vars = tf.global_variables()
    if not all_vars:
        # Empty graph: nothing to query, and no empty fetch list for sess.run().
        return
    # One boolean op per variable; a True flag means the variable already
    # holds a value.
    init_flags = sess.run([tf.is_variable_initialized(var) for var in all_vars])
    uninit = [var for var, is_init in zip(all_vars, init_flags) if not is_init]
    # Skip building/running an initializer op when everything is initialized.
    if uninit:
        sess.run(tf.variables_initializer(uninit))
#-------------------------------------------------------------------------------
def initialize_uninitialized(sess, verbose=False):
    """Initialize only the global variables that are not yet initialized.

    :param sess: the TensorFlow session used to query and run initializers
    :param verbose: if True, print the names of the variables about to be
        initialized (useful while debugging)
    :return: None
    """
    global_vars = tf.global_variables()
    # Empty graph: nothing to check, and no empty fetch list for sess.run().
    if not global_vars:
        return
    # A True flag means the corresponding variable already holds a value.
    is_initialized = sess.run(
        [tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for (v, f) in zip(global_vars, is_initialized)
                            if not f]
    # Debug output is opt-in rather than printed on every call.
    if verbose:
        print([str(i.name) for i in not_initialized_vars])
    if not_initialized_vars:
        sess.run(tf.variables_initializer(not_initialized_vars))
def _init_uninitialized(sess):
    """Run initializers for every global variable that has no value yet.

    Returns the list of variables that were initialized; the list is empty
    when the graph has no variables or all of them were already initialized.
    """
    all_vars = tf.global_variables()
    # Avoid calling sess.run() with an empty fetch list.
    if not all_vars:
        return []
    flags = sess.run([tf.is_variable_initialized(var) for var in all_vars])
    pending = []
    for var, already_init in zip(all_vars, flags):
        if not already_init:
            pending.append(var)
    if not pending:
        return []
    sess.run(tf.variables_initializer(pending))
    return pending
def testIsVariableInitialized(self):
    """Check is_variable_initialized before and after an assign, on GPU and CPU."""
    for gpu_flag in (True, False):
        with self.test_session(use_gpu=gpu_flag):
            var = state_ops.variable_op([1, 2], tf.complex64)
            # Before any assignment the variable reports as uninitialized.
            before = tf.is_variable_initialized(var).eval()
            self.assertEqual(False, before)
            # After assigning a value it reports as initialized.
            tf.assign(var, [[2.0 + 3.0j, 3.0 + 2.0j]]).eval()
            after = tf.is_variable_initialized(var).eval()
            self.assertEqual(True, after)
def _build(self):
    """Build the data-holder tensor and an op reporting whether it is initialized."""
    param = self._build_parameter()
    self._dataholder_tensor = param
    self._is_initialized_tensor = tf.is_variable_initialized(param)
def _build(self):
    """Build the parameter's tensors and store them on the instance.

    Creates the unconstrained tensor, the constrained tensor derived from it,
    and the prior built from both. It also creates an op that reports whether
    the unconstrained tensor has been initialized.
    """
    raw = self._build_parameter()
    transformed = self._build_constrained(raw)
    prior_tensor = self._build_prior(raw, transformed)
    self._is_initialized_tensor = tf.is_variable_initialized(raw)
    self._unconstrained_tensor = raw
    self._constrained_tensor = transformed
    self._prior_tensor = prior_tensor
def get_uninitialized_variables(variables=None, sess=None):
    """Return a list of uninitialized tf variables.

    Parameters
    ----------
    variables: tf.Variable, list(tf.Variable), optional
        Filter variable list to only those that are uninitialized. If no
        variables are specified the list of all variables in the graph
        will be used.
    sess: tf.Session, optional
        Session used to evaluate the initialization flags. Defaults to
        ``tf.get_default_session()``.

    Returns
    -------
    list(tf.Variable)
        List of uninitialized tf variables.

    Raises
    ------
    ValueError
        If the variable list is non-empty, no ``sess`` is given and there is
        no default session.
    """
    if variables is None:
        variables = tf.global_variables()
    elif isinstance(variables, tf.Variable):
        # A lone variable is documented as valid input, but list() would try
        # to iterate over it.
        variables = [variables]
    else:
        variables = list(variables)
    if not variables:
        return []
    if sess is None:
        sess = tf.get_default_session()
        if sess is None:
            raise ValueError(
                "get_uninitialized_variables() needs a session: pass `sess` "
                "or install a default session.")
    # tf.pack was renamed to tf.stack in TF 1.0; detect the available op
    # instead of parsing tf.__version__.
    stack = getattr(tf, "stack", None) or tf.pack
    init_flag = sess.run(
        stack([tf.is_variable_initialized(v) for v in variables]))
    return [v for v, f in zip(variables, init_flag) if not f]
def get_uninitialized_variables(variables=None, sess=None):
    """Return a list of uninitialized tf variables.

    Parameters
    ----------
    variables: tf.Variable, list(tf.Variable), optional
        Filter variable list to only those that are uninitialized. If no
        variables are specified the list of all variables in the graph
        will be used.
    sess: tf.Session, optional
        Session used to evaluate the initialization flags. Defaults to
        ``tf.get_default_session()``.

    Returns
    -------
    list(tf.Variable)
        List of uninitialized tf variables.

    Raises
    ------
    ValueError
        If the variable list is non-empty, no ``sess`` is given and there is
        no default session.
    """
    if variables is None:
        variables = tf.global_variables()
    elif isinstance(variables, tf.Variable):
        # A lone variable is documented as valid input, but list() would try
        # to iterate over it.
        variables = [variables]
    else:
        variables = list(variables)
    if not variables:
        return []
    if sess is None:
        sess = tf.get_default_session()
        if sess is None:
            raise ValueError(
                "get_uninitialized_variables() needs a session: pass `sess` "
                "or install a default session.")
    # tf.pack was renamed to tf.stack in TF 1.0; detect the available op
    # instead of parsing tf.__version__.
    stack = getattr(tf, "stack", None) or tf.pack
    init_flag = sess.run(
        stack([tf.is_variable_initialized(v) for v in variables]))
    return [v for v, f in zip(variables, init_flag) if not f]
# Tears of the debugging...
def get_uninitialized_variables(variables=None, sess=None):
    """Return a list of uninitialized tf variables.

    Parameters
    ----------
    variables: tf.Variable, list(tf.Variable), optional
        Filter variable list to only those that are uninitialized. If no
        variables are specified the list of all variables in the graph
        will be used.
    sess: tf.Session, optional
        Session used to evaluate the initialization flags. Defaults to
        ``tf.get_default_session()``.

    Returns
    -------
    list(tf.Variable)
        List of uninitialized tf variables.

    Raises
    ------
    ValueError
        If the variable list is non-empty, no ``sess`` is given and there is
        no default session.
    """
    if variables is None:
        variables = tf.global_variables()
    elif isinstance(variables, tf.Variable):
        # A lone variable is documented as valid input, but list() would try
        # to iterate over it.
        variables = [variables]
    else:
        variables = list(variables)
    if not variables:
        return []
    if sess is None:
        sess = tf.get_default_session()
        if sess is None:
            raise ValueError(
                "get_uninitialized_variables() needs a session: pass `sess` "
                "or install a default session.")
    # tf.pack was renamed to tf.stack in TF 1.0; detect the available op
    # instead of parsing tf.__version__.
    stack = getattr(tf, "stack", None) or tf.pack
    init_flag = sess.run(
        stack([tf.is_variable_initialized(v) for v in variables]))
    return [v for v, f in zip(variables, init_flag) if not f]
# Tears of the debugging...