def begin(self):
    """Create a ``tf.train.Saver`` for the variables under ``params["prefix"]``.

    Collects the variables in scope ``self.params["prefix"]`` and maps each
    one to its name in the checkpoint. That name is the graph name with every
    component of the prefix except the last one removed from the front. For
    example, with prefix ``"a/b/c"`` the graph variable ``a/b/c/w`` is looked
    up as ``c/w``. With a single-component prefix such as ``"c"``, names are
    used unchanged. The mapping is logged and stored on ``self._saver``.
    """
    prefix = self.params["prefix"]
    variables = tf.contrib.framework.get_variables(scope=prefix)

    # Everything before the last prefix component is absent from the
    # checkpoint's variable names. Compute it once rather than per variable.
    checkpoint_prefix = "/".join(prefix.split("/")[:-1])
    # An empty checkpoint prefix means there is nothing to strip. The previous
    # str.replace("/", "") in that case removed every slash from the name.
    leading = checkpoint_prefix + "/" if checkpoint_prefix else ""

    def varname_in_checkpoint(name):
        """Removes the leading checkpoint prefix from the variable name.

        Only a prefix at the start of ``name`` is removed; occurrences
        elsewhere in the name are preserved.
        """
        if leading and name.startswith(leading):
            return name[len(leading):]
        return name

    target_names = [varname_in_checkpoint(var.op.name) for var in variables]
    restore_map = dict(zip(target_names, variables))
    tf.logging.info("Restoring variables: \n%s",
                    yaml.dump({k: v.op.name
                               for k, v in restore_map.items()}))
    self._saver = tf.train.Saver(restore_map)
评论列表
文章目录