def uninitialized_variables(session, var_list=None):
    """Gets the list of uninitialized variables.

    Note: this has to be evaluated on a session.

    Parameters
    ----------
    session: tf.Session
        The TensorFlow session to scan for uninitialized variables.
    var_list: list(tf.Variable) or None
        The list of variables to filter for uninitialized ones.
        Defaults to ``tf.all_variables()`` when None.

    Returns
    -------
    list(tf.Variable)
        The variables from ``var_list`` that ``session`` reports as
        uninitialized, in the order in which they were reported. Reported
        names that cannot be matched to a variable in ``var_list`` are
        skipped with a printed message.
    """
    if var_list is None:
        var_list = tf.all_variables()
    # Materialise once: the list is iterated both to build the name index
    # and by report_uninitialized_variables, so a generator would be
    # exhausted after the first pass.
    var_list = list(var_list)

    # report_uninitialized_variables returns the *op names* of the
    # uninitialized variables, so index the candidates by that same key.
    # (tf.get_variable(name) must not be used here: without reuse=True it
    # raises or creates a brand-new variable instead of looking one up.)
    vars_by_name = {var.op.name: var for var in var_list}

    reported_var_names = session.run(tf.report_uninitialized_variables(var_list))
    uninit_vars = []
    for name in reported_var_names:
        # session.run yields a string tensor, which arrives as bytes on
        # Python 3; decode so it matches the str keys of vars_by_name.
        if isinstance(name, bytes):
            name = name.decode("utf-8")
        var = vars_by_name.get(name)
        if var is None:
            print("Failed to collect variable {}. Skipping.".format(name))
            continue
        uninit_vars.append(var)
    return uninit_vars
评论列表
文章目录