def initialize_variables(self, save_file=None):
    """Initialize all global variables, optionally restoring them from a checkpoint.

    Every global variable is first initialized with its own initializer. If
    ``save_file`` is given, ``self.saver`` restores the session from it. If
    that full restore fails because some variable of the current graph has no
    entry in the checkpoint (``tf.errors.NotFoundError``), a partial restore is
    done instead. Only graph variables whose name is stored in the checkpoint,
    with a compatible shape, are restored; all others keep their freshly
    initialized values. This lets a checkpoint written by an older version of
    the graph be loaded into a newer one.

    Args:
        save_file: Checkpoint path prefix, as accepted by
            ``tf.train.Saver.restore``, or ``None`` to only initialize.

    Raises:
        Any exception from ``self.saver.restore`` other than
        ``tf.errors.NotFoundError`` propagates unchanged, as does any error
        raised while reading the checkpoint or doing the partial restore.
    """
    import logging

    logger = logging.getLogger(__name__)

    self.session.run(tf.global_variables_initializer())
    if save_file is None:
        return

    try:
        self.saver.restore(self.session, save_file)
    except tf.errors.NotFoundError:
        # The checkpoint lacks at least one variable of the current graph.
        # Restore only the variables that actually are in the save file;
        # everything else keeps its normal initial value.
        logger.info(
            "Full restore from %s failed; falling back to a partial restore.",
            save_file,
        )

        # Ask the checkpoint itself what it stores (name -> shape). The .meta
        # graph is not a reliable source: it may list variables that were
        # never saved, and it may not have been written at all.
        reader = tf.train.NewCheckpointReader(save_file)
        stored_shapes = reader.get_variable_to_shape_map()
        logger.debug("Variables stored in %s: %s", save_file, sorted(stored_shapes))

        var_list = []
        for var in tf.global_variables():
            name = var.op.name
            if name not in stored_shapes:
                continue
            if not var.get_shape().is_compatible_with(stored_shapes[name]):
                # Restoring a reshaped variable would fail; treat it as new.
                logger.warning(
                    "Not restoring %s: graph shape %s is incompatible with "
                    "stored shape %s.",
                    name, var.get_shape(), stored_shapes[name],
                )
                continue
            var_list.append(var)

        # Re-initialize defensively so no variable is left in whatever state
        # the failed restore attempt may have produced.
        self.session.run(tf.global_variables_initializer())

        if not var_list:
            # tf.train.Saver raises ValueError for an empty var_list; with
            # nothing to restore, the fresh initial values are the result.
            logger.warning(
                "No variables in %s match the current graph; using freshly "
                "initialized values only.",
                save_file,
            )
            return

        # A throwaway saver limited to the matching variables "upgrades" an
        # old checkpoint to the current graph definition.
        throwaway_saver = tf.train.Saver(var_list=var_list)
        throwaway_saver.restore(self.session, save_file)
评论列表
文章目录