def load_trainable_vars(sess, filename):
    """Restore trainable variables from a ``.npz`` archive.

    Each array stored in the archive is assigned to the trainable variable
    whose ``name`` equals the archive key. Archive entries whose key matches
    no trainable variable are not assigned; they are collected and returned.

    Args:
        sess: Session used to run the assignment ops.
        filename: Path (or file-like object) of the ``.npz`` archive, as
            accepted by ``np.load``.

    Returns:
        dict mapping archive key -> ndarray for every archive entry that did
        not correspond to a trainable variable. If the archive cannot be
        opened (``IOError``/``OSError``, e.g. the file does not exist),
        nothing is restored and an empty dict is returned.
    """
    other = {}
    try:
        archive = np.load(filename)
    except IOError:
        # Best-effort: an unopenable/missing checkpoint simply means
        # "nothing to restore", so the caller gets an empty dict.
        return other

    # Key variables by their full graph name (``v.name``) so they can be
    # matched directly against the archive keys.
    tv = {str(v.name): v for v in tf.trainable_variables()}

    # np.load returns an NpzFile that keeps the underlying file open; the
    # context manager guarantees it is closed even if an assignment fails.
    with archive:
        for k, d in archive.items():
            if k in tv:
                print('restoring ' + k)
                # NOTE(review): in TF1 graph mode tf.assign adds a new op to
                # the graph on every call; if this function is invoked
                # repeatedly, consider Variable.load or a placeholder-fed
                # assign op instead.
                sess.run(tf.assign(tv[k], d))
            else:
                other[k] = d
    return other
评论列表
文章目录