def save(network, sess, filename=None):
    """Save the variables contained by a network to disk.

    Each variable is keyed by the last two components of its scoped name:
    a variable named ``"outer/module/weight:0"`` is stored at
    ``to_save["module"]["weight"]``.

    Args:
        network: Module whose variables are gathered with
            ``nn.get_variables_in_module``.
        sess: Session passed to each variable's ``eval`` to read its value.
        filename: Optional path. When truthy, the nested dict is pickled to
            this file (opened in binary write mode).

    Returns:
        A ``collections.defaultdict(dict)`` mapping module name to a dict of
        ``variable name -> evaluated value``. Values are whatever ``eval``
        returns (presumably numpy arrays -- confirm for your session type).

    Raises:
        ValueError: If a variable's name has no enclosing scope, so no module
            name can be derived. Also raised if two variables map to the same
            ``(module, variable)`` key, which would otherwise silently
            overwrite one saved value with another. Both checks run before
            anything is written to ``filename``.
    """
    to_save = collections.defaultdict(dict)
    # Full name of the variable that claimed each (module, variable) key.
    # It is kept only so that a collision can be reported helpfully.
    owners = {}
    for v in nn.get_variables_in_module(network):
        # Drop the ":<output index>" suffix, then split the scope path.
        parts = v.name.split(":")[0].split("/")
        if len(parts) < 2:
            raise ValueError(
                "Cannot derive a module name for unscoped variable %r." % v.name)
        module_name, variable_name = parts[-2], parts[-1]
        key = (module_name, variable_name)
        if key in owners:
            raise ValueError(
                "Variables %r and %r both map to key %r; saving both would "
                "silently overwrite one of them." % (owners[key], v.name, key))
        owners[key] = v.name
        to_save[module_name][variable_name] = v.eval(sess)
    if filename:
        with open(filename, "wb") as f:
            pickle.dump(to_save, f)
    return to_save
评论列表
文章目录