def get_uninitialized_variables(variables=None):
    """Return a list of uninitialized tf variables.

    Parameters
    ----------
    variables: tf.Variable, list(tf.Variable), optional
        Filter variable list to only those that are uninitialized. If no
        variables are specified the list of all variables in the graph
        (``tf.global_variables()``) will be used.

    Returns
    -------
    list(tf.Variable)
        List of uninitialized tf variables, in the same order as they
        appear in ``variables``.

    Raises
    ------
    RuntimeError
        If there are variables to check but no default ``tf.Session`` is
        active to evaluate the initialization checks in.
    """
    if variables is None:
        variables = tf.global_variables()
    elif isinstance(variables, tf.Variable):
        # A single Variable is documented as valid input, but calling
        # list() on it raises TypeError because Variables are not iterable.
        variables = [variables]
    else:
        variables = list(variables)
    if not variables:
        return []

    # Looked up only after the empty-input early return, so an empty
    # request still succeeds without an active session.
    sess = tf.get_default_session()
    if sess is None:
        raise RuntimeError(
            'get_uninitialized_variables requires a default tf.Session; '
            'call it inside a `with session.as_default():` block.')

    # Session.run accepts a list of fetches and returns one value per
    # fetch. This avoids stacking the flags with tf.pack/tf.stack, whose
    # name differs across TF versions.
    init_flags = sess.run(
        [tf.is_variable_initialized(v) for v in variables])
    return [v for v, initialized in zip(variables, init_flags)
            if not initialized]
# Tears of the debugging...
评论列表
文章目录